From 784c3c1f5e2993209a034e870bdd3bb9739aa366 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Thu, 27 Jun 2024 15:22:36 +0200 Subject: [PATCH 01/19] Parsing LLM Response for Kotlin tests --- build.gradle.kts | 29 ++++ .../generation/llm/LLMWithFeedbackCycle.kt | 21 ++- .../testspark/core/generation/llm/Utils.kt | 3 + .../generation/llm/network/RequestManager.kt | 12 +- .../testspark/core/test/TestsAssembler.kt | 3 +- .../core/test/parsers/TestSuiteParser.kt | 7 + .../parsers/java/JavaJUnitTestSuiteParser.kt | 22 +++ .../kotlin/KotlinJUnitTestSuiteParser.kt | 22 +++ .../JUnitTestSuiteParserStrategy.kt} | 135 +++++++++--------- .../research/testspark/core/utils/Language.kt | 8 ++ .../research/testspark/core/utils/Patterns.kt | 16 ++- .../kotlin/KotlinJUnitTestSuiteParserTest.kt | 125 ++++++++++++++++ .../testspark/java/JavaPsiClassWrapper.kt | 8 +- .../resources/META-INF/testspark-java.xml | 9 ++ .../testspark/kotlin/KotlinPsiClassWrapper.kt | 8 +- .../testspark/kotlin/KotlinPsiHelper.kt | 2 +- .../resources/META-INF/testspark-kotlin.xml | 8 ++ .../testspark/langwrappers/PsiComponents.kt | 5 +- .../testspark/actions/TestSparkAction.kt | 2 +- .../actions/llm/LLMSampleSelectorFactory.kt | 2 +- .../testspark/display/TestCasePanelFactory.kt | 4 +- .../helpers/JavaClassBuilderHelper.kt | 2 +- .../research/testspark/helpers/LLMHelper.kt | 3 + .../services/TestCaseDisplayService.kt | 5 +- .../research/testspark/tools/Pipeline.kt | 8 +- .../evosuite/EvoSuiteSettingsArguments.kt | 2 +- .../research/testspark/tools/llm/Llm.kt | 4 + .../tools/llm/LlmSettingsArguments.kt | 2 +- .../llm/generation/JUnitTestsAssembler.kt | 20 ++- .../tools/llm/generation/LLMProcessManager.kt | 4 + .../tools/llm/generation/PromptManager.kt | 2 +- src/main/resources/META-INF/plugin.xml | 2 +- .../SettingsArgumentsLlmEvoSuiteTest.kt | 2 +- 33 files changed, 394 insertions(+), 113 deletions(-) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt rename core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/{java/JUnitTestSuiteParser.kt => strategies/JUnitTestSuiteParserStrategy.kt} (52%) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt create mode 100644 core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt create mode 100644 java/src/main/resources/META-INF/testspark-java.xml create mode 100644 kotlin/src/main/resources/META-INF/testspark-kotlin.xml diff --git a/build.gradle.kts b/build.gradle.kts index 13da233c4..b83f7e6bd 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,5 +1,6 @@ import org.jetbrains.changelog.markdownToHTML import org.jetbrains.intellij.tasks.RunIdeTask +import org.jetbrains.intellij.tasks.RunPluginVerifierTask import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import java.io.FileOutputStream import java.net.URL @@ -208,6 +209,15 @@ tasks { dependsOn(":core:compileKotlin") } + verifyPlugin { + dependsOn(":copyPluginAssets") + onlyIf { this.project == rootProject } + } + + runIde { + onlyIf { this.project == rootProject } + } + // Set the JVM compatibility versions properties("javaVersion").let { withType { @@ -286,6 +296,25 @@ tasks { // https://plugins.jetbrains.com/docs/intellij/deployment.html#specifying-a-release-channel channels.set(listOf(properties("pluginVersion").split('-').getOrElse(1) { "default" }.split('.').first())) } + + withType { + onlyIf { this.project == rootProject } + mustRunAfter("check") + + // 1.365 is broken, +// remove this version as soon as https://youtrack.jetbrains.com/issue/MP-6438 is fixed. +// verifierVersion.set("1.364") + ideVersions.set(properties("ideVersionVerifier").split(",")) + failureLevel.set( + listOf( + RunPluginVerifierTask.FailureLevel.INTERNAL_API_USAGES, + RunPluginVerifierTask.FailureLevel.COMPATIBILITY_PROBLEMS, + RunPluginVerifierTask.FailureLevel.OVERRIDE_ONLY_API_USAGES, + RunPluginVerifierTask.FailureLevel.NON_EXTENDABLE_API_USAGES, + RunPluginVerifierTask.FailureLevel.PLUGIN_STRUCTURE_WARNINGS, + ) + ) + } } abstract class CopyJUnitRunnerLib : DefaultTask() { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index c0bb34ff2..0c8a428aa 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -16,6 +16,7 @@ import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.Language import java.io.File enum class FeedbackCycleExecutionResult { @@ -44,6 +45,7 @@ data class FeedbackResponse( class LLMWithFeedbackCycle( private val report: Report, + private val language: Language, private val initialPromptMessage: String, private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in result path @@ -99,6 +101,7 @@ class LLMWithFeedbackCycle( // clearing test assembler's collected text on the previous attempts testsAssembler.clear() val response: LLMResponse = requestManager.request( + language = language, prompt = nextPromptMessage, indicator = indicator, packageName = packageName, @@ -119,6 +122,7 @@ class LLMWithFeedbackCycle( continue } } + ResponseErrorCode.PROMPT_TOO_LONG -> { if (promptSizeReductionStrategy.isReductionPossible()) { nextPromptMessage = promptSizeReductionStrategy.reduceSizeAndGeneratePrompt() @@ -132,11 +136,13 @@ class LLMWithFeedbackCycle( break } } + ResponseErrorCode.EMPTY_LLM_RESPONSE -> { nextPromptMessage = "You have provided an empty answer! Please, answer my previous question with the same formats" continue } + ResponseErrorCode.TEST_SUITE_PARSING_FAILURE -> { onWarningCallback?.invoke(WarningType.TEST_SUITE_PARSING_FAILED) log.info { "Cannot parse a test suite from the LLM response. LLM response: '$response'" } @@ -161,7 +167,8 @@ class LLMWithFeedbackCycle( generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) } else { for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + val testCaseFilename = + "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) @@ -205,8 +212,10 @@ class LLMWithFeedbackCycle( // Compile the test file indicator.setText("Compilation tests checking") - val testCasesCompilationResult = testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases) - val testSuiteCompilationResult = testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath) + val testCasesCompilationResult = + testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases) + val testSuiteCompilationResult = + testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath) // saving the compilable test cases compilableTestCases.addAll(testCasesCompilationResult.compilableTestCases) @@ -216,7 +225,8 @@ class LLMWithFeedbackCycle( onWarningCallback?.invoke(WarningType.COMPILATION_ERROR_OCCURRED) - nextPromptMessage = "I cannot compile the tests that you provided. The error is:\n${testSuiteCompilationResult.second}\n Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text." + nextPromptMessage = + "I cannot compile the tests that you provided. The error is:\n${testSuiteCompilationResult.second}\n Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text." log.info { nextPromptMessage } continue } @@ -226,7 +236,8 @@ class LLMWithFeedbackCycle( generatedTestsArePassing = true for (index in testCases.indices) { - report.testCaseList[index] = TestCase(index, testCases[index].name, testCases[index].toString(), setOf()) + report.testCaseList[index] = + TestCase(index, testCases[index].name, testCases[index].toString(), setOf()) } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index 1f018a27a..76cb74c17 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -6,6 +6,7 @@ import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.Language import java.util.Locale // TODO: find a better place for the below functions @@ -38,6 +39,7 @@ fun getClassWithTestCaseName(testCaseName: String): String { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun executeTestCaseModificationRequest( + language: Language, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -59,6 +61,7 @@ fun executeTestCaseModificationRequest( } val response = requestManager.request( + language, prompt, indicator, packageName, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index 705db1d72..689eec798 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -8,6 +8,7 @@ import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler +import org.jetbrains.research.testspark.core.utils.Language abstract class RequestManager(var token: String) { enum class SendResult { @@ -30,6 +31,7 @@ abstract class RequestManager(var token: String) { * @return the generated TestSuite, or null and prompt message */ open fun request( + language: Language, prompt: String, indicator: CustomProgressIndicator, packageName: String, @@ -55,14 +57,15 @@ abstract class RequestManager(var token: String) { } return when (isUserFeedback) { - true -> processUserFeedbackResponse(testsAssembler, packageName) - false -> processResponse(testsAssembler, packageName) + true -> processUserFeedbackResponse(testsAssembler, packageName, language) + false -> processResponse(testsAssembler, packageName, language) } } open fun processResponse( testsAssembler: TestsAssembler, packageName: String, + language: Language, ): LLMResponse { // save the full response in the chat history val response = testsAssembler.getContent() @@ -75,7 +78,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) @@ -94,6 +97,7 @@ abstract class RequestManager(var token: String) { open fun processUserFeedbackResponse( testsAssembler: TestsAssembler, packageName: String, + language: Language, ): LLMResponse { val response = testsAssembler.getContent() @@ -104,7 +108,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index a761e53b2..6e5a4e127 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -1,6 +1,7 @@ package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.Language abstract class TestsAssembler { private var rawText = "" @@ -37,5 +38,5 @@ abstract class TestsAssembler { * @param packageName The package name to be set in the generated TestSuite. * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name. */ - abstract fun assembleTestSuite(packageName: String): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt index 9ce724888..a0551ed7c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt @@ -1,7 +1,14 @@ package org.jetbrains.research.testspark.core.test.parsers +import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +data class TestCaseParseResult( + val testCase: TestCaseGeneratedByLLM?, + val errorMessage: String, + val errorOccurred: Boolean, +) + interface TestSuiteParser { /** * Extracts test cases from raw text and generates a test suite using the given package name. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt new file mode 100644 index 000000000..a8728bbf2 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt @@ -0,0 +1,22 @@ +package org.jetbrains.research.testspark.core.test.parsers.java + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser +import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy + +class JavaJUnitTestSuiteParser( + private val packageName: String, + private val junitVersion: JUnitVersion, + private val importPattern: Regex, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + return JUnitTestSuiteParserStrategy.parseTestSuite( + rawText, + junitVersion, + importPattern, + packageName, + testNamePattern = "void", + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt new file mode 100644 index 000000000..09bdbc627 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt @@ -0,0 +1,22 @@ +package org.jetbrains.research.testspark.core.test.parsers.kotlin + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser +import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy + +class KotlinJUnitTestSuiteParser( + private val packageName: String, + private val junitVersion: JUnitVersion, + private val importPattern: Regex, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + return JUnitTestSuiteParserStrategy.parseTestSuite( + rawText, + junitVersion, + importPattern, + packageName, + testNamePattern = "fun", + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt similarity index 52% rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt index 2186a61c7..98c6827c5 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt @@ -1,94 +1,93 @@ -package org.jetbrains.research.testspark.core.test.parsers.java +package org.jetbrains.research.testspark.core.test.parsers.strategies import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestLine import org.jetbrains.research.testspark.core.test.data.TestLineType import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.utils.importPattern - -class JUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - var rawCode = rawText - - if (rawText.contains("```")) { - rawCode = rawText.split("```")[1] +import org.jetbrains.research.testspark.core.test.parsers.TestCaseParseResult + +class JUnitTestSuiteParserStrategy { + companion object { + fun parseTestSuite( + rawText: String, + junitVersion: JUnitVersion, + importPattern: Regex, + packageName: String, + testNamePattern: String, + ): TestSuiteGeneratedByLLM? { + if (rawText.isBlank()) { + return null } - // save imports - val imports = importPattern.findAll(rawCode, 0) - .map { it.groupValues[0] } - .toSet() + try { + var rawCode = rawText - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" + if (rawText.contains("```")) { + rawCode = rawText.split("```")[1] + } - val testSet: MutableList = rawCode.split("@Test").toMutableList() + // save imports + val imports = importPattern.findAll(rawCode, 0) + .map { it.groupValues[0] } + .toSet() - // save annotations and pre-set methods - val otherInfo: String = run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } + // save RunWith + val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() + val testSet: MutableList = rawCode.split("@Test").toMutableList() - testSet.forEach ca@{ - val rawTest = "@Test$it" + // save annotations and pre-set methods + val otherInfo: String = run { + val otherInfoList = testSet.removeAt(0).split("{").toMutableList() + otherInfoList.removeFirst() + val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" + otherInfo.ifBlank { "" } + } - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: TestCaseParseResult = testCaseParser.parse(rawTest, isLastTestCaseInTestSuite) + // Save the main test cases + val testCases: MutableList = mutableListOf() + val testCaseParser = JUnitTestCaseParser() - if (result.errorOccurred) { - println("WARNING: ${result.errorMessage}") - return@ca - } + testSet.forEach ca@{ + val rawTest = "@Test$it" - val currentTest = result.testCase!! + val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) + val result: TestCaseParseResult = + testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern) // /// - // TODO: make logging work - // log.info("New test case: $currentTest") - println("New test case: $currentTest") + if (result.errorOccurred) { + println("WARNING: ${result.errorMessage}") + return@ca + } - testCases.add(currentTest) - } + val currentTest = result.testCase!! - val testSuite = TestSuiteGeneratedByLLM( - imports = imports, - packageString = packageName, - runWith = runWith, - otherInfo = otherInfo, - testCases = testCases, - ) + // TODO: make logging work + // log.info("New test case: $currentTest") + println("New test case: $currentTest") + + testCases.add(currentTest) + } - return testSuite - } catch (e: Exception) { - return null + val testSuite = TestSuiteGeneratedByLLM( + imports = imports, + packageString = packageName, + runWith = runWith, + otherInfo = otherInfo, + testCases = testCases, + ) + + return testSuite + } catch (e: Exception) { + return null + } } } } -private data class TestCaseParseResult( - val testCase: TestCaseGeneratedByLLM?, - val errorMessage: String, - val errorOccurred: Boolean, -) - private class JUnitTestCaseParser { - fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean): TestCaseParseResult { + fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean, testNamePattern: String): TestCaseParseResult { var expectedException = "" var throwsException = "" val testLines: MutableList = mutableListOf() @@ -99,10 +98,10 @@ private class JUnitTestCaseParser { } // Get unexpected exceptions - /* Each test case should follow [public] void {...} + /* Each test case should follow fun {...} Tests do not return anything so it is safe to consider that void always appears before test case name */ - val voidString = "void" + val voidString = testNamePattern if (!rawTest.contains(voidString)) { return TestCaseParseResult( testCase = null, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt new file mode 100644 index 000000000..250ec7cba --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt @@ -0,0 +1,8 @@ +package org.jetbrains.research.testspark.core.utils + +/** + * Language ID string should be the same as the language name in com.intellij.lang.Language + */ +enum class Language(val languageId: String) { + Java("JAVA"), Kotlin("Kotlin") +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt index 123610ac7..95903bf8c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt @@ -1,13 +1,25 @@ package org.jetbrains.research.testspark.core.utils -val importPattern = +val javaImportPattern = Regex( pattern = "^import\\s+(static\\s)?((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;", options = setOf(RegexOption.MULTILINE), ) -val packagePattern = +val kotlinImportPattern = + Regex( + pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?", + options = setOf(RegexOption.MULTILINE), + ) + +val javaPackagePattern = Regex( pattern = "^package\\s+((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;", options = setOf(RegexOption.MULTILINE), ) + +val kotlinPackagePattern = + Regex( + pattern = "^package\\s+((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?", + options = setOf(RegexOption.MULTILINE), + ) diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt new file mode 100644 index 000000000..2ebcde0c9 --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -0,0 +1,125 @@ +package org.jetbrains.research.testspark.core.test.parsers.kotlin + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.junit.jupiter.api.Test +import kotlin.test.assertNotNull + +class KotlinJUnitTestSuiteParserTest { + + @Test + fun testFunction() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Assertions.* + import org.junit.jupiter.api.Test + import org.mockito.Mockito.* + import org.mockito.kotlin.any + import org.mockito.kotlin.eq + import org.mockito.kotlin.mock + import org.test.Message as TestMessage + + class MyClassTest { + + @Test + fun compileTestCases_AllCompilableTest() { + // Arrange + val myClass = MyClass() + val generatedTestCasesPaths = listOf("path1", "path2") + val buildPath = "buildPath" + val testCase1 = TestCaseGeneratedByLLM() + val testCase2 = TestCaseGeneratedByLLM() + val testCases = mutableListOf(testCase1, testCase2) + + val myClassSpy = spy(myClass) + doReturn(Pair(true, "")).`when`(myClassSpy).compileCode(any(), eq(buildPath)) + + // Act + val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases) + + // Assert + assertTrue(result.allTestCasesCompilable) + assertEquals(setOf(testCase1, testCase2), result.compilableTestCases) + } + + @Test + fun compileTestCases_NoneCompilableTest() { + // Arrange + val myClass = MyClass() + val generatedTestCasesPaths = listOf("path1", "path2") + val buildPath = "buildPath" + val testCase1 = TestCaseGeneratedByLLM() + val testCase2 = TestCaseGeneratedByLLM() + val testCases = mutableListOf(testCase1, testCase2) + + val myClassSpy = spy(myClass) + doReturn(Pair(false, "")).`when`(myClassSpy).compileCode(any(), eq(buildPath)) + + // Act + val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases) + + // Assert + assertFalse(result.allTestCasesCompilable) + assertTrue(result.compilableTestCases.isEmpty()) + } + + @Test + fun compileTestCases_SomeCompilableTest() { + // Arrange + val myClass = MyClass() + val generatedTestCasesPaths = listOf("path1", "path2") + val buildPath = "buildPath" + val testCase1 = TestCaseGeneratedByLLM() + val testCase2 = TestCaseGeneratedByLLM() + val testCases = mutableListOf(testCase1, testCase2) + + val myClassSpy = spy(myClass) + doReturn(Pair(true, "")).`when`(myClassSpy).compileCode(eq("path1"), eq(buildPath)) + doReturn(Pair(false, "")).`when`(myClassSpy).compileCode(eq("path2"), eq(buildPath)) + + // Act + val result = myClassSpy.compileTestCases(generatedTestCasesPaths, buildPath, testCases) + + // Assert + assertFalse(result.allTestCasesCompilable) + assertEquals(setOf(testCase1), result.compilableTestCases) + } + + @Test + fun compileTestCases_EmptyTestCasesTest() { + // Arrange + val myClass = MyClass() + val generatedTestCasesPaths = emptyList() + val buildPath = "buildPath" + val testCases = mutableListOf() + + // Act + val result = myClass.compileTestCases(generatedTestCasesPaths, buildPath, testCases) + + // Assert + assertTrue(result.allTestCasesCompilable) + assertTrue(result.compilableTestCases.isEmpty()) + } + + @Test(expected = ArithmeticException::class, Exception::class) + fun compileTestCases_omg() { + val blackHole = 1 / 0 + } + } + ``` + """.trimIndent() + val parser = KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assert(testSuite.imports.contains("import org.mockito.Mockito.*")) + assert(testSuite.imports.contains("import org.test.Message as TestMessage")) + assert(testSuite.imports.contains("import org.mockito.kotlin.mock")) + assert(testSuite.testCases[0].name == "compileTestCases_AllCompilableTest") + assert(testSuite.testCases[1].name == "compileTestCases_NoneCompilableTest") + assert(testSuite.testCases[2].name == "compileTestCases_SomeCompilableTest") + assert(testSuite.testCases[3].name == "compileTestCases_EmptyTestCasesTest") + assert(testSuite.testCases[4].name == "compileTestCases_omg") + assert(testSuite.testCases[4].expectedException.isNotBlank()) + } +} diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index 725451c89..007bdbff7 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -10,8 +10,8 @@ import com.intellij.psi.search.GlobalSearchScope import com.intellij.psi.search.searches.ClassInheritorsSearch import com.intellij.psi.util.PsiTypesUtil import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.utils.importPattern -import org.jetbrains.research.testspark.core.utils.packagePattern +import org.jetbrains.research.testspark.core.utils.javaImportPattern +import org.jetbrains.research.testspark.core.utils.javaPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper @@ -38,14 +38,14 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { val fileText = psiClass.containingFile.text // get package - packagePattern.findAll(fileText).map { + javaPackagePattern.findAll(fileText).map { it.groupValues[0] }.forEach { fullText += "$it\n\n" } // get imports - importPattern.findAll(fileText).map { + javaImportPattern.findAll(fileText).map { it.groupValues[0] }.forEach { fullText += "$it\n" diff --git a/java/src/main/resources/META-INF/testspark-java.xml b/java/src/main/resources/META-INF/testspark-java.xml new file mode 100644 index 000000000..180580ca7 --- /dev/null +++ b/java/src/main/resources/META-INF/testspark-java.xml @@ -0,0 +1,9 @@ + + + + + + diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index 829e7698d..8ac75755c 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -17,8 +17,8 @@ import org.jetbrains.kotlin.psi.KtObjectDeclaration import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.utils.importPattern -import org.jetbrains.research.testspark.core.utils.packagePattern +import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper @@ -66,14 +66,14 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra val fileText = psiClass.containingFile.text // get package - packagePattern.findAll(fileText, 0).map { + kotlinPackagePattern.findAll(fileText, 0).map { it.groupValues[0] }.forEach { fullText += "$it\n\n" } // get imports - importPattern.findAll(fileText, 0).map { + kotlinImportPattern.findAll(fileText, 0).map { it.groupValues[0] }.forEach { fullText += "$it\n" diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 076427fd1..13749bd35 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -26,7 +26,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper -class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { +class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { override val language: Language get() = Language.Kotlin diff --git a/kotlin/src/main/resources/META-INF/testspark-kotlin.xml b/kotlin/src/main/resources/META-INF/testspark-kotlin.xml new file mode 100644 index 000000000..22e5e05c8 --- /dev/null +++ b/kotlin/src/main/resources/META-INF/testspark-kotlin.xml @@ -0,0 +1,8 @@ + + + + + + \ No newline at end of file diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index 92a600ba6..f61dc7a1b 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -5,10 +5,7 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile import org.jetbrains.research.testspark.core.data.ClassType - -enum class Language(val languageName: String) { - Java("Java"), Kotlin("Kotlin") -} +import org.jetbrains.research.testspark.core.utils.Language /** * Interface representing a wrapper for PSI methods, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 5a0a96fbc..3b08ca009 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -393,4 +393,4 @@ class TestSparkAction : AnAction() { } override fun getActionUpdateThread(): ActionUpdateThread = ActionUpdateThread.BGT -} +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index b57ee8d81..bb2c5a53f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,7 +4,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper +import org.jetbrains.research.testspark.java.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup import javax.swing.JButton diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt index 341dc95fd..f17e8720b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt @@ -58,6 +58,7 @@ import javax.swing.border.MatteBorder class TestCasePanelFactory( private val project: Project, + private val language: org.jetbrains.research.testspark.core.utils.Language, private val testCase: TestCase, editor: Editor, private val checkbox: JCheckBox, @@ -90,7 +91,7 @@ class TestCasePanelFactory( // Add an editor to modify the test source code private val languageTextField = LanguageTextField( - Language.findLanguageByID("JAVA"), + Language.findLanguageByID(language.languageId), editor.project, testCase.testCode, TestCaseDocumentCreator( @@ -408,6 +409,7 @@ class TestCasePanelFactory( } val modifiedTest = LLMHelper.testModificationRequest( + language, initialCodes[currentRequestNumber - 1], requestComboBox.editor.item.toString(), ijIndicator, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt index 977873bdb..cf62202b3 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt @@ -201,4 +201,4 @@ object JavaClassBuilderHelper { return result } -} +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index 6c08c4653..b36fe381a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -13,6 +13,7 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.settings.llm.LLMSettingsState import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager @@ -229,6 +230,7 @@ object LLMHelper { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun testModificationRequest( + language: Language, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -243,6 +245,7 @@ object LLMHelper { } val testSuite = executeTestCaseModificationRequest( + language, testCase, task, indicator, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt index 337cee754..e3b11555a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt @@ -31,6 +31,7 @@ import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.display.TestCasePanelFactory import org.jetbrains.research.testspark.display.TopButtonsPanelFactory @@ -109,7 +110,7 @@ class TestCaseDisplayService(private val project: Project) { * Fill the panel with the generated test cases. Remove all previously shown test cases. * Add Tests and their names to a List of pairs (used for highlighting) */ - fun displayTestCases(report: Report, uiContext: UIContext) { + fun displayTestCases(report: Report, uiContext: UIContext, language: Language) { this.report = report this.uiContext = uiContext @@ -145,7 +146,7 @@ class TestCaseDisplayService(private val project: Project) { } testCasePanel.add(checkbox, BorderLayout.WEST) - val testCasePanelFactory = TestCasePanelFactory(project, testCase, editor, checkbox, uiContext, report) + val testCasePanelFactory = TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index 5efe4461b..aa5b694b7 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -36,9 +36,9 @@ import java.util.UUID */ class Pipeline( private val project: Project, - psiHelper: PsiHelper, - caretOffset: Int, - fileUrl: String?, + private val psiHelper: PsiHelper, + private val caretOffset: Int, + private val fileUrl: String?, private val packageName: String, private val testGenerationController: TestGenerationController, ) { @@ -114,7 +114,7 @@ class Pipeline( if (project.service().editor != null) { val report = it.testGenerationOutput.testGenerationResultList[0]!! - project.service().displayTestCases(report, it) + project.service().displayTestCases(report, it, psiHelper.language) project.service().showCoverage(report) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt index 18a593ca2..20b47872c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt @@ -175,4 +175,4 @@ class EvoSuiteSettingsArguments( return if (command == "-Dcriterion=") "-Dcriterion=LINE" else command } } -} +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 1bc2ddb5a..01f16176c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -50,6 +50,7 @@ class Llm(override val name: String = "LLM") : Tool { return LLMProcessManager( project, + psiHelper.language, PromptManager(project, psiHelper, caretOffset), testSamplesCode, projectSDKPath, @@ -81,6 +82,7 @@ class Llm(override val name: String = "LLM") : Tool { createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( project, + psiHelper.language, PromptManager(project, psiHelper, caretOffset), testSamplesCode, ), @@ -114,6 +116,7 @@ class Llm(override val name: String = "LLM") : Tool { createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( project, + psiHelper.language, PromptManager(project, psiHelper, caretOffset), testSamplesCode, ), @@ -147,6 +150,7 @@ class Llm(override val name: String = "LLM") : Tool { createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( project, + psiHelper.language, PromptManager(project, psiHelper, caretOffset), testSamplesCode, ), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 437ecd679..921980ed6 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -70,4 +70,4 @@ class LlmSettingsArguments(private val project: Project) { llmSettingsState.grazieName -> llmSettingsState.grazieModel else -> "" } -} +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index 77cc2646c..e1bcb67ec 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -9,7 +9,10 @@ import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.java.JUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.parsers.java.JavaJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.parsers.kotlin.KotlinJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.utils.javaImportPattern import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState @@ -55,10 +58,10 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(packageName: String): TestSuiteGeneratedByLLM? { + override fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? { val junitVersion = llmSettingsState.junitVersion - val parser = createTestSuiteParser(packageName, junitVersion) + val parser = createTestSuiteParser(packageName, junitVersion, language) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(super.getContent()) // save RunWith @@ -78,7 +81,14 @@ class JUnitTestsAssembler( return testSuite } - private fun createTestSuiteParser(packageName: String, jUnitVersion: JUnitVersion): TestSuiteParser { - return JUnitTestSuiteParser(packageName, jUnitVersion) + private fun createTestSuiteParser( + packageName: String, + jUnitVersion: JUnitVersion, + language: Language, + ): TestSuiteParser { + return when (language) { + Language.Java -> JavaJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) + Language.Kotlin -> KotlinJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) + } } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index 6ed499b73..bb1dee0ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -13,6 +13,7 @@ import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext @@ -41,6 +42,7 @@ import java.nio.file.Path */ class LLMProcessManager( private val project: Project, + private val language: Language, private val promptManager: PromptManager, private val testSamplesCode: String, projectSDKPath: Path? = null, @@ -88,6 +90,7 @@ class LLMProcessManager( val report = IJReport() + // PROMPT GENERATION val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) val testCompiler = testProcessor.testCompiler @@ -125,6 +128,7 @@ class LLMProcessManager( // Asking LLM to generate a test suite. Here we have a feedback cycle for LLM in case of wrong responses val llmFeedbackCycle = LLMWithFeedbackCycle( + language = language, report = report, initialPromptMessage = initialPromptMessage, promptSizeReductionStrategy = promptSizeReductionStrategy, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index 906873088..d7ac8f9f5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -83,7 +83,7 @@ class PromptManager( classesToTest = classesToTest.map(this::createClassRepresentation).toList(), polymorphismRelations = polymorphismRelations, promptConfiguration = PromptConfiguration( - desiredLanguage = psiHelper.language.languageName, + desiredLanguage = psiHelper.language.languageId, desiredTestingPlatform = llmSettingsState.junitVersion.showName, desiredMockingFramework = "Mockito 5", ), diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index dffd2b46d..75aa7c65b 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -1,6 +1,6 @@ - org.jetbrains.research.testgenie + org.jetbrains.research.testspark TestSpark ictl diff --git a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt index a05013d13..934a4fac5 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt @@ -215,4 +215,4 @@ class SettingsArgumentsLlmEvoSuiteTest { criterion, ).isEqualTo("-Dcriterion=LINE:BRANCH:EXCEPTION:WEAKMUTATION:OUTPUT:METHOD:METHODNOEXCEPTION:CBRANCH") } -} +} \ No newline at end of file From 6fb9f274b3de538272cba74ad32f93d7e4513586 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Thu, 27 Jun 2024 15:27:39 +0200 Subject: [PATCH 02/19] last fixes of merge conflicts --- build.gradle.kts | 29 ------------------- .../resources/META-INF/testspark-java.xml | 9 ------ .../resources/META-INF/testspark-kotlin.xml | 8 ----- .../actions/llm/LLMSampleSelectorFactory.kt | 2 +- .../helpers/JavaClassBuilderHelper.kt | 2 +- .../evosuite/EvoSuiteSettingsArguments.kt | 2 +- .../tools/llm/LlmSettingsArguments.kt | 2 +- src/main/resources/META-INF/plugin.xml | 2 +- .../SettingsArgumentsLlmEvoSuiteTest.kt | 2 +- 9 files changed, 6 insertions(+), 52 deletions(-) delete mode 100644 java/src/main/resources/META-INF/testspark-java.xml delete mode 100644 kotlin/src/main/resources/META-INF/testspark-kotlin.xml diff --git a/build.gradle.kts b/build.gradle.kts index b83f7e6bd..13da233c4 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,6 +1,5 @@ import org.jetbrains.changelog.markdownToHTML import org.jetbrains.intellij.tasks.RunIdeTask -import org.jetbrains.intellij.tasks.RunPluginVerifierTask import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import java.io.FileOutputStream import java.net.URL @@ -209,15 +208,6 @@ tasks { dependsOn(":core:compileKotlin") } - verifyPlugin { - dependsOn(":copyPluginAssets") - onlyIf { this.project == rootProject } - } - - runIde { - onlyIf { this.project == rootProject } - } - // Set the JVM compatibility versions properties("javaVersion").let { withType { @@ -296,25 +286,6 @@ tasks { // https://plugins.jetbrains.com/docs/intellij/deployment.html#specifying-a-release-channel channels.set(listOf(properties("pluginVersion").split('-').getOrElse(1) { "default" }.split('.').first())) } - - withType { - onlyIf { this.project == rootProject } - mustRunAfter("check") - - // 1.365 is broken, -// remove this version as soon as https://youtrack.jetbrains.com/issue/MP-6438 is fixed. -// verifierVersion.set("1.364") - ideVersions.set(properties("ideVersionVerifier").split(",")) - failureLevel.set( - listOf( - RunPluginVerifierTask.FailureLevel.INTERNAL_API_USAGES, - RunPluginVerifierTask.FailureLevel.COMPATIBILITY_PROBLEMS, - RunPluginVerifierTask.FailureLevel.OVERRIDE_ONLY_API_USAGES, - RunPluginVerifierTask.FailureLevel.NON_EXTENDABLE_API_USAGES, - RunPluginVerifierTask.FailureLevel.PLUGIN_STRUCTURE_WARNINGS, - ) - ) - } } abstract class CopyJUnitRunnerLib : DefaultTask() { diff --git a/java/src/main/resources/META-INF/testspark-java.xml b/java/src/main/resources/META-INF/testspark-java.xml deleted file mode 100644 index 180580ca7..000000000 --- a/java/src/main/resources/META-INF/testspark-java.xml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - diff --git a/kotlin/src/main/resources/META-INF/testspark-kotlin.xml b/kotlin/src/main/resources/META-INF/testspark-kotlin.xml deleted file mode 100644 index 22e5e05c8..000000000 --- a/kotlin/src/main/resources/META-INF/testspark-kotlin.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index bb2c5a53f..b57ee8d81 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,7 +4,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.java.LLMTestSampleHelper +import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup import javax.swing.JButton diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt index cf62202b3..977873bdb 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt @@ -201,4 +201,4 @@ object JavaClassBuilderHelper { return result } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt index 20b47872c..18a593ca2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt @@ -175,4 +175,4 @@ class EvoSuiteSettingsArguments( return if (command == "-Dcriterion=") "-Dcriterion=LINE" else command } } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 921980ed6..437ecd679 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -70,4 +70,4 @@ class LlmSettingsArguments(private val project: Project) { llmSettingsState.grazieName -> llmSettingsState.grazieModel else -> "" } -} \ No newline at end of file +} diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 75aa7c65b..dffd2b46d 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -1,6 +1,6 @@ - org.jetbrains.research.testspark + org.jetbrains.research.testgenie TestSpark ictl diff --git a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt index 934a4fac5..a05013d13 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt @@ -215,4 +215,4 @@ class SettingsArgumentsLlmEvoSuiteTest { criterion, ).isEqualTo("-Dcriterion=LINE:BRANCH:EXCEPTION:WEAKMUTATION:OUTPUT:METHOD:METHODNOEXCEPTION:CBRANCH") } -} \ No newline at end of file +} From 9756d76455bfe63d4e2446048b9cf3a014c143f3 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Thu, 27 Jun 2024 15:28:52 +0200 Subject: [PATCH 03/19] klint --- .../org/jetbrains/research/testspark/actions/TestSparkAction.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 3b08ca009..5a0a96fbc 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -393,4 +393,4 @@ class TestSparkAction : AnAction() { } override fun getActionUpdateThread(): ActionUpdateThread = ActionUpdateThread.BGT -} \ No newline at end of file +} From c724c62e96a77a02eeb3fb29b4fd2fc7b8a4396e Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 20:17:39 +0200 Subject: [PATCH 04/19] TestClassBuilderHelper refactoring --- build.gradle.kts | 1 + .../testspark/core/data/TestGenerationData.kt | 4 +- .../generation/llm/LLMWithFeedbackCycle.kt | 14 +- .../testspark/core/generation/llm/Utils.kt | 48 +- .../generation/llm/network/RequestManager.kt | 12 +- .../generation/llm/prompt/PromptBuilder.kt | 6 +- .../generation/llm/prompt/PromptGenerator.kt | 6 +- .../llm/prompt/configuration/Configuration.kt | 3 +- .../research/testspark/core/test/Language.kt | 8 - .../testspark/core/test/TestCompiler.kt | 60 +- .../testspark/core/test/TestSuiteParser.kt | 2 +- .../core/test/TestSuiteParserStrategy.kt | 14 - .../testspark/core/test/TestsAssembler.kt | 7 +- .../core/test/TestsPersistentStorage.kt | 1 + .../core/test/data/TestCaseGeneratedByLLM.kt | 29 +- .../core/test/data/TestSuiteGeneratedByLLM.kt | 4 +- .../JavaTestCompilationDependencies.kt | 30 - .../test/java/JavaJUnitTestSuiteParser.kt | 21 +- .../test/kotlin/KotlinJUnitTestSuiteParser.kt | 21 +- .../JUnitTestSuiteParserStrategy.kt | 268 ++++----- .../research/testspark/core/utils/Patterns.kt | 10 +- .../kotlin/KotlinJUnitTestSuiteParserTest.kt | 70 ++- .../testspark/java/JavaPsiClassWrapper.kt | 2 + .../research/testspark/java/JavaPsiHelper.kt | 55 +- .../testspark/kotlin/KotlinPsiClassWrapper.kt | 2 + .../testspark/kotlin/KotlinPsiHelper.kt | 63 ++- .../kotlin/KotlinPsiMethodWrapper.kt | 16 + .../testspark/langwrappers/PsiComponents.kt | 36 +- .../testspark/actions/TestSparkAction.kt | 79 +-- .../actions/llm/LLMSampleSelectorFactory.kt | 5 +- .../actions/llm/TestSamplePanelFactory.kt | 4 +- .../testspark/appstarter/TestSparkStarter.kt | 14 +- .../research/testspark/data/CodeType.kt | 8 - .../testspark/data/FragmentToTestData.kt | 2 + .../testspark/display/TestCasePanelFactory.kt | 53 +- .../display/TopButtonsPanelFactory.kt | 70 +-- .../testspark/helpers/CoverageHelper.kt | 6 +- .../helpers/JavaClassBuilderHelper.kt | 204 ------- .../research/testspark/helpers/LLMHelper.kt | 26 +- .../CoverageToolWindowDisplayService.kt | 0 .../services/TestCaseDisplayService.kt | 527 +----------------- .../testspark/tools/LibraryPathsProvider.kt | 4 +- .../research/testspark/tools/Pipeline.kt | 45 +- .../testspark/tools/TestCompilerFactory.kt | 17 +- .../research/testspark/tools/TestProcessor.kt | 61 +- .../research/testspark/tools/ToolUtils.kt | 45 +- .../testspark/tools/evosuite/EvoSuite.kt | 2 +- .../generation/EvoSuiteProcessManager.kt | 9 +- .../research/testspark/tools/llm/Llm.kt | 15 +- .../llm/generation/JUnitTestsAssembler.kt | 35 +- .../tools/llm/generation/LLMProcessManager.kt | 75 ++- .../tools/llm/generation/PromptManager.kt | 34 +- .../tools/llm/test/JUnitTestSuitePresenter.kt | 20 +- 53 files changed, 835 insertions(+), 1338 deletions(-) delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/Language.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParserStrategy.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt diff --git a/build.gradle.kts b/build.gradle.kts index 13da233c4..5e6621e29 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -157,6 +157,7 @@ dependencies { // https://mvnrepository.com/artifact/org.mockito/mockito-all testImplementation("org.mockito:mockito-all:1.10.19") + testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0") // https://mvnrepository.com/artifact/net.jqwik/jqwik testImplementation("net.jqwik:jqwik:1.6.5") diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt index d11f346d5..a35212cb1 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt @@ -16,7 +16,7 @@ data class TestGenerationData( // Code required of imports and package for generated tests var importsCode: MutableSet = mutableSetOf(), - var packageLine: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", @@ -37,7 +37,7 @@ data class TestGenerationData( resultName = "" fileUrl = "" importsCode = mutableSetOf() - packageLine = "" + packageName = "" runWith = "" otherInfo = "" polyDepthReducing = 0 diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 67d4ab653..973b26e7a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -10,7 +10,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeRed import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage @@ -45,7 +45,7 @@ data class FeedbackResponse( class LLMWithFeedbackCycle( private val report: Report, - private val language: Language, + private val language: SupportedLanguage, private val initialPromptMessage: String, private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in result path @@ -167,13 +167,15 @@ class LLMWithFeedbackCycle( generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) } else { for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = - "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + val testCaseFilename = when (language) { + SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt" + } val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) val saveFilepath = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testCaseRepresentation, resultPath, testCaseFilename, @@ -184,7 +186,7 @@ class LLMWithFeedbackCycle( } val generatedTestSuitePath: String = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testsPresenter.representTestSuite(generatedTestSuite), resultPath, testSuiteFilename, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index 2e0e00f01..1942a6a86 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -4,13 +4,47 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.utils.javaPackagePattern +import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import java.util.Locale // TODO: find a better place for the below functions +/** + * Retrieves the package declaration from the given test suite code for any language. + * + * @param testSuiteCode The generated code of the test suite. + * @return The package name extracted from the test suite code, or an empty string if no package declaration was found. + */ +fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String { + testSuiteCode ?: return "" + return when (language) { + SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + } +} + +/** + * Retrieves the imports code from a given test suite code. + * + * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. + * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. + * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. + */ +fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String?): MutableSet { + testSuiteCode ?: return mutableSetOf() + return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() + .filter { it.contains("^import".toRegex()) } + .filterNot { it.contains("evosuite".toRegex()) } + .filterNot { it.contains("RunWith".toRegex()) } + // classFQN will be null for the top level function + .filterNot { classFQN != null && it.contains(classFQN.toRegex()) } + .toMutableSet() +} + /** * Returns the generated class name for a given test case. * @@ -39,7 +73,7 @@ fun getClassWithTestCaseName(testCaseName: String): String { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun executeTestCaseModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -50,15 +84,7 @@ fun executeTestCaseModificationRequest( // Update Token information val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task" - var packageName = "" - testCase.split("\n")[0].let { - if (it.startsWith("package")) { - packageName = it - .removePrefix("package ") - .removeSuffix(";") - .trim() - } - } + val packageName = getPackageFromTestSuiteCode(testCase, language) val response = requestManager.request( language, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index eacd2393e..441e51231 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -7,7 +7,7 @@ import org.jetbrains.research.testspark.core.data.ChatUserMessage import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler abstract class RequestManager(var token: String) { @@ -31,7 +31,7 @@ abstract class RequestManager(var token: String) { * @return the generated TestSuite, or null and prompt message */ open fun request( - language: Language, + language: SupportedLanguage, prompt: String, indicator: CustomProgressIndicator, packageName: String, @@ -65,7 +65,7 @@ abstract class RequestManager(var token: String) { open fun processResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { // save the full response in the chat history val response = testsAssembler.getContent() @@ -78,7 +78,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) @@ -97,7 +97,7 @@ abstract class RequestManager(var token: String) { open fun processUserFeedbackResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { val response = testsAssembler.getContent() @@ -108,7 +108,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 9d645ef9d..036e87a0d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -78,7 +78,7 @@ internal class PromptBuilder(private var prompt: String) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" } for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java")) { + if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { continue } @@ -88,7 +88,9 @@ internal class PromptBuilder(private var prompt: String) { // Skip java methods // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java")) { + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin") + ) { continue } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt index 3afbd3cff..72340867a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt @@ -19,7 +19,7 @@ class PromptGenerator( fun generatePromptForClass(interestingClasses: List, testSamplesCode: String): String { val prompt = PromptBuilder(promptTemplates.classPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(context.cut.qualifiedName) + .insertName(context.cut!!.qualifiedName) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(context.cut.fullText, context.classesToTest) @@ -44,10 +44,12 @@ class PromptGenerator( method: MethodRepresentation, interestingClassesFromMethod: List, testSamplesCode: String, + packageName: String, ): String { + val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" val prompt = PromptBuilder(promptTemplates.methodPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName("${context.cut.qualifiedName}.${method.name}") + .insertName(name) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(method.text, context.classesToTest) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt index 4094de1aa..7bc95fc5f 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt @@ -10,7 +10,8 @@ import org.jetbrains.research.testspark.core.data.ClassType * @property polymorphismRelations A map where the key represents a ClassRepresentation object and the value is a list of its detected subclasses. */ data class PromptGenerationContext( - val cut: ClassRepresentation, + // The cut is null when we want to generate tests for top-level function + val cut: ClassRepresentation?, val classesToTest: List, val polymorphismRelations: Map>, val promptConfiguration: PromptConfiguration, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/Language.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/Language.kt deleted file mode 100644 index 605bfaa5b..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/Language.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.core.test - -/** - * Language ID string should be the same as the language name in com.intellij.lang.Language - */ -enum class Language(val languageId: String) { - Java("JAVA"), Kotlin("Kotlin") -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt index bc4d40617..b49281aaf 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt @@ -1,32 +1,24 @@ package org.jetbrains.research.testspark.core.test -import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil -import java.io.File data class TestCasesCompilationResult( val allTestCasesCompilable: Boolean, val compilableTestCases: MutableSet, ) -/** - * TestCompiler is a class that is responsible for compiling generated test cases using the proper javac. - * It provides methods for compiling test cases and code files. - */ -open class TestCompiler( - private val javaHomeDirectoryPath: String, +abstract class TestCompiler( private val libPaths: List, private val junitLibPaths: List, ) { - private val log = KotlinLogging.logger { this::class.java } - /** - * Compiles the generated files with test cases using the proper javac. + * Compiles a list of test cases and returns the compilation result. * - * @return true if all the provided test cases are successfully compiled, - * otherwise returns false. + * @param generatedTestCasesPaths A list of file paths where the generated test cases are located. + * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. + * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled. + * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases. */ fun compileTestCases( generatedTestCasesPaths: List, @@ -51,45 +43,11 @@ open class TestCompiler( * Compiles the code at the specified path using the provided project build path. * * @param path The path of the code file to compile. - * @param projectBuildPath The project build path to use during compilation. + * @param projectBuildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. * @return A pair containing a boolean value indicating whether the compilation was successful (true) or not (false), * and a string message describing any error encountered during compilation. */ - fun compileCode(path: String, projectBuildPath: String): Pair { - // find the proper javac - val javaCompile = File(javaHomeDirectoryPath).walk() - .filter { - val isCompilerName = if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") - isCompilerName && it.isFile - } - .firstOrNull() - - if (javaCompile == null) { - val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" - log.error { msg } - throw RuntimeException(msg) - } - - println("javac found at '${javaCompile.absolutePath}'") - - // compile file - val errorMsg = CommandLineRunner.run( - arrayListOf( - javaCompile.absolutePath, - "-cp", - "\"${getPath(projectBuildPath)}\"", - path, - ), - ) - - log.info { "Error message: '$errorMsg'" } - - // create .class file path - val classFilePath = path.replace(".java", ".class") - - // check is .class file exists - return Pair(File(classFilePath).exists(), errorMsg) - } + abstract fun compileCode(path: String, projectBuildPath: String): Pair /** * Generates the path for the command by concatenating the necessary paths. @@ -97,7 +55,7 @@ open class TestCompiler( * @param buildPath The path of the build file. * @return The generated path as a string. */ - fun getPath(buildPath: String): String { + fun getClassPaths(buildPath: String): String { // create the path for the command val separator = DataFilesUtil.classpathSeparator val dependencyLibPath = libPaths.joinToString(separator.toString()) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt index a32baed03..60c4016d4 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt @@ -11,7 +11,7 @@ data class TestCaseParseResult( interface TestSuiteParser { /** - * Extracts test cases from raw text and generates a test suite using the given package name. + * Extracts test cases from raw text and generates a test suite. * * @param rawText The raw text provided by the LLM that contains the generated test cases. * @return A GeneratedTestSuite instance containing the extracted test cases. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParserStrategy.kt deleted file mode 100644 index 83e34370c..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParserStrategy.kt +++ /dev/null @@ -1,14 +0,0 @@ -package org.jetbrains.research.testspark.core.test - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM - -interface TestSuiteParserStrategy { - fun parseTestSuite( - rawText: String, - junitVersion: JUnitVersion, - importPattern: Regex, - packageName: String, - testNamePattern: String, - ): TestSuiteGeneratedByLLM? -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index dba5ab859..0d9c672de 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -32,10 +32,9 @@ abstract class TestsAssembler { } /** - * Extracts test cases from raw text and generates a TestSuite using the given package name. + * Extracts test cases from raw text and generates a TestSuite. * - * @param packageName The package name to be set in the generated TestSuite. - * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name. + * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases. */ - abstract fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM? } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt index 1673fea4a..b9d50132c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt @@ -4,6 +4,7 @@ package org.jetbrains.research.testspark.core.test * The TestPersistentStorage interface represents a contract for saving generated tests to a specified file system location. */ interface TestsPersistentStorage { + /** * Save the generated tests to a specified directory. * diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt index 6ef9f6907..2a565e82e 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.core.test.data +import org.jetbrains.research.testspark.core.test.TestBodyPrinter + /** * * Represents a test case generated by LLM. @@ -11,6 +13,7 @@ data class TestCaseGeneratedByLLM( var expectedException: String = "", var throwsException: String = "", var lines: MutableList = mutableListOf(), + val printTestBodyStrategy: TestBodyPrinter, ) { /** @@ -104,31 +107,7 @@ data class TestCaseGeneratedByLLM( * @return a string containing the body of test case */ private fun printTestBody(testInitiatedText: String): String { - var testFullText = testInitiatedText - - // start writing the test signature - testFullText += "\n\tpublic void $name() " - - // add throws exception if exists - if (throwsException.isNotBlank()) { - testFullText += "throws $throwsException" - } - - // start writing the test lines - testFullText += "{\n" - - // write each line - lines.forEach { line -> - testFullText += when (line.type) { - TestLineType.BREAK -> "\t\t\n" - else -> "\t\t${line.text}\n" - } - } - - // close test case - testFullText += "\t}\n" - - return testFullText + return printTestBodyStrategy.printTestBody(testInitiatedText, lines, throwsException, name) } /** diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt index 211063bb7..4fac9b8b9 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt @@ -4,12 +4,12 @@ package org.jetbrains.research.testspark.core.test.data * Represents a test suite generated by LLM. * * @property imports The set of import statements in the test suite. - * @property packageString The package string of the test suite. + * @property packageName The package name of the test suite. * @property testCases The list of test cases in the test suite. */ data class TestSuiteGeneratedByLLM( var imports: Set = emptySet(), - var packageString: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", var testCases: MutableList = mutableListOf(), diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt deleted file mode 100644 index 2e78b0b50..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt +++ /dev/null @@ -1,30 +0,0 @@ -package org.jetbrains.research.testspark.core.test.data.dependencies - -import org.jetbrains.research.testspark.core.data.JarLibraryDescriptor - -/** - * The class represents a list of dependencies required for java test compilation. - * The libraries listed are used during test suite/test case compilation. - */ -class JavaTestCompilationDependencies { - companion object { - fun getJarDescriptors() = listOf( - JarLibraryDescriptor( - "mockito-core-5.0.0.jar", - "https://repo1.maven.org/maven2/org/mockito/mockito-core/5.0.0/mockito-core-5.0.0.jar", - ), - JarLibraryDescriptor( - "hamcrest-core-1.3.jar", - "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", - ), - JarLibraryDescriptor( - "byte-buddy-1.14.6.jar", - "https://repo1.maven.org/maven2/net/bytebuddy/byte-buddy/1.14.6/byte-buddy-1.14.6.jar", - ), - JarLibraryDescriptor( - "byte-buddy-agent-1.14.6.jar", - "https://repo1.maven.org/maven2/net/bytebuddy/byte-buddy-agent/1.14.6/byte-buddy-agent-1.14.6.jar", - ), - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt index 6a2ad88b3..279badc57 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt @@ -1,23 +1,32 @@ package org.jetbrains.research.testspark.core.test.java import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter import org.jetbrains.research.testspark.core.test.TestSuiteParser -import org.jetbrains.research.testspark.core.test.TestSuiteParserStrategy import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.javaImportPattern class JavaJUnitTestSuiteParser( - private val packageName: String, + private var packageName: String, private val junitVersion: JUnitVersion, - private val importPattern: Regex, - private val parsingStrategy: TestSuiteParserStrategy, + private val testBodyPrinter: TestBodyPrinter, ) : TestSuiteParser { override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return parsingStrategy.parseTestSuite( + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Java) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( rawText, junitVersion, - importPattern, + javaImportPattern, packageName, testNamePattern = "void", + testBodyPrinter, ) } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt index 91911ea8e..18b164810 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt @@ -1,23 +1,32 @@ package org.jetbrains.research.testspark.core.test.kotlin import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter import org.jetbrains.research.testspark.core.test.TestSuiteParser -import org.jetbrains.research.testspark.core.test.TestSuiteParserStrategy import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.kotlinImportPattern class KotlinJUnitTestSuiteParser( - private val packageName: String, + private var packageName: String, private val junitVersion: JUnitVersion, - private val importPattern: Regex, - private val parsingStrategy: TestSuiteParserStrategy, + private val testBodyPrinter: TestBodyPrinter, ) : TestSuiteParser { override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return parsingStrategy.parseTestSuite( + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Kotlin) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( rawText, junitVersion, - importPattern, + kotlinImportPattern, packageName, testNamePattern = "fun", + testBodyPrinter, ) } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt index 85cf70681..7bc818cd0 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt @@ -1,167 +1,175 @@ package org.jetbrains.research.testspark.core.test.strategies import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.TestBodyPrinter import org.jetbrains.research.testspark.core.test.TestCaseParseResult -import org.jetbrains.research.testspark.core.test.TestSuiteParserStrategy import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestLine import org.jetbrains.research.testspark.core.test.data.TestLineType import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -class JUnitTestSuiteParserStrategy : TestSuiteParserStrategy { - override fun parseTestSuite( - rawText: String, - junitVersion: JUnitVersion, - importPattern: Regex, - packageName: String, - testNamePattern: String, - ): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText +class JUnitTestSuiteParserStrategy { + companion object { + fun parseJUnitTestSuite( + rawText: String, + junitVersion: JUnitVersion, + importPattern: Regex, + packageName: String, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestSuiteGeneratedByLLM? { + if (rawText.isBlank()) { + return null + } - // save imports - val imports = importPattern.findAll(rawCode) - .map { it.groupValues[0] } - .toSet() + try { + val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" + // save imports + val imports = importPattern.findAll(rawCode) + .map { it.groupValues[0] } + .toSet() - val testSet: MutableList = rawCode.split("@Test").toMutableList() + // save RunWith + val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" - // save annotations and pre-set methods - val otherInfo: String = run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } + val testSet: MutableList = rawCode.split("@Test").toMutableList() - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() + // save annotations and pre-set methods + val otherInfo: String = run { + val otherInfoList = testSet.removeAt(0).split("{").toMutableList() + otherInfoList.removeFirst() + val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" + otherInfo.ifBlank { "" } + } - testSet.forEach ca@{ - val rawTest = "@Test$it" + // Save the main test cases + val testCases: MutableList = mutableListOf() + val testCaseParser = JUnitTestCaseParser() - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: TestCaseParseResult = - testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern) + testSet.forEach ca@{ + val rawTest = "@Test$it" - if (result.errorOccurred) { - println("WARNING: ${result.errorMessage}") - return@ca - } + val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) + val result: TestCaseParseResult = + testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy) - val currentTest = result.testCase!! + if (result.errorOccurred) { + println("WARNING: ${result.errorMessage}") + return@ca + } - // TODO: make logging work - // log.info("New test case: $currentTest") - println("New test case: $currentTest") + val currentTest = result.testCase!! - testCases.add(currentTest) - } + // TODO: make logging work + // log.info("New test case: $currentTest") - val testSuite = TestSuiteGeneratedByLLM( - imports = imports, - packageString = packageName, - runWith = runWith, - otherInfo = otherInfo, - testCases = testCases, - ) + testCases.add(currentTest) + } - return testSuite - } catch (e: Exception) { - return null + val testSuite = TestSuiteGeneratedByLLM( + imports = imports, + packageName = packageName, + runWith = runWith, + otherInfo = otherInfo, + testCases = testCases, + ) + + return testSuite + } catch (e: Exception) { + return null + } } } -} - -private class JUnitTestCaseParser { - fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean, testNamePattern: String): TestCaseParseResult { - var expectedException = "" - var throwsException = "" - val testLines: MutableList = mutableListOf() - // Get expected Exception - if (rawTest.startsWith("@Test(expected =")) { - expectedException = rawTest.split(")")[0].trim() - } + private class JUnitTestCaseParser { + fun parse( + rawTest: String, + isLastTestCaseInTestSuite: Boolean, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestCaseParseResult { + var expectedException = "" + var throwsException = "" + val testLines: MutableList = mutableListOf() + + // Get expected Exception + if (rawTest.startsWith("@Test(expected =")) { + expectedException = rawTest.split(")")[0].trim() + } - // Get unexpected exceptions - /* Each test case should follow fun {...} - Tests do not return anything so it is safe to consider that void always appears before test case name - */ - if (!rawTest.contains(testNamePattern)) { - return TestCaseParseResult( - testCase = null, - errorMessage = "The raw Test does not contain $testNamePattern:\n $rawTest", - errorOccurred = true, - ) - } - val interestingPartOfSignature = rawTest.split(testNamePattern)[1] - .split("{")[0] - .split("()")[1] - .trim() + // Get unexpected exceptions + /* Each test case should follow fun {...} + Tests do not return anything so it is safe to consider that void always appears before test case name + */ + if (!rawTest.contains(testNamePattern)) { + return TestCaseParseResult( + testCase = null, + errorMessage = "The raw Test does not contain $testNamePattern:\n $rawTest", + errorOccurred = true, + ) + } + val interestingPartOfSignature = rawTest.split(testNamePattern)[1] + .split("{")[0] + .split("()")[1] + .trim() - if (interestingPartOfSignature.contains("throws")) { - throwsException = interestingPartOfSignature.split("throws")[1].trim() - } + if (interestingPartOfSignature.contains("throws")) { + throwsException = interestingPartOfSignature.split("throws")[1].trim() + } - // Get test name - val testName: String = rawTest.split(testNamePattern)[1] - .split("()")[0] - .trim() - - // Get test body and remove opening bracket - var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } - .joinToString("{").trim() - - // remove closing bracket - val tempList = testBody.split("}").toMutableList() - tempList.removeLast() - - if (isLastTestCaseInTestSuite) { - // it is the last test, thus we should remove another closing bracket - if (tempList.isNotEmpty()) { - tempList.removeLast() - } else { - println("WARNING: the final test does not have the enclosing bracket:\n $testBody") + // Get test name + val testName: String = rawTest.split(testNamePattern)[1] + .split("()")[0] + .trim() + + // Get test body and remove opening bracket + var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } + .joinToString("{").trim() + + // remove closing bracket + val tempList = testBody.split("}").toMutableList() + tempList.removeLast() + + if (isLastTestCaseInTestSuite) { + // it is the last test, thus we should remove another closing bracket + if (tempList.isNotEmpty()) { + tempList.removeLast() + } else { + println("WARNING: the final test does not have the enclosing bracket:\n $testBody") + } } - } - testBody = tempList.joinToString("}") + testBody = tempList.joinToString("}") - // Save each line - val rawLines = testBody.split("\n").toMutableList() - rawLines.forEach { rawLine -> - val line = rawLine.trim() + // Save each line + val rawLines = testBody.split("\n").toMutableList() + rawLines.forEach { rawLine -> + val line = rawLine.trim() - val type: TestLineType = when { - line.startsWith("//") -> TestLineType.COMMENT - line.isBlank() -> TestLineType.BREAK - line.lowercase().startsWith("assert") -> TestLineType.ASSERTION - else -> TestLineType.CODE + val type: TestLineType = when { + line.startsWith("//") -> TestLineType.COMMENT + line.isBlank() -> TestLineType.BREAK + line.lowercase().startsWith("assert") -> TestLineType.ASSERTION + else -> TestLineType.CODE + } + + testLines.add(TestLine(type, line)) } - testLines.add(TestLine(type, line)) - } + val currentTest = TestCaseGeneratedByLLM( + name = testName, + expectedException = expectedException, + throwsException = throwsException, + lines = testLines, + printTestBodyStrategy = printTestBodyStrategy, + ) - val currentTest = TestCaseGeneratedByLLM( - name = testName, - expectedException = expectedException, - throwsException = throwsException, - lines = testLines, - ) - - return TestCaseParseResult( - testCase = currentTest, - errorMessage = "", - errorOccurred = false, - ) + return TestCaseParseResult( + testCase = currentTest, + errorMessage = "", + errorOccurred = false, + ) + } } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt index 95903bf8c..fb1da6841 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt @@ -6,9 +6,17 @@ val javaImportPattern = options = setOf(RegexOption.MULTILINE), ) +/** + * Parse all the possible Kotlin import patterns + * + * import org.mockito.Mockito.`when` + * import kotlin.math.cos + * import kotlin.math.* + * import kotlin.math.PI as piValue + */ val kotlinImportPattern = Regex( - pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?", + pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?(`\\w*`)?", options = setOf(RegexOption.MULTILINE), ) diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt index b2b5a865c..63fbd0abc 100644 --- a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -3,8 +3,7 @@ package org.jetbrains.research.testspark.core.test.parsers.kotlin import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy -import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertTrue @@ -114,9 +113,9 @@ class KotlinJUnitTestSuiteParserTest { ``` """.trimIndent() - val parsingStrategy = JUnitTestSuiteParserStrategy() + val testBodyPrinter = KotlinTestBodyPrinter() val parser = - KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern, parsingStrategy) + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) assertTrue(testSuite!!.imports.contains("import org.mockito.Mockito.*")) @@ -143,17 +142,20 @@ class KotlinJUnitTestSuiteParserTest { fun testParseEmptyTestSuite() { val text = """ ```kotlin + package com.example.testsuite + class EmptyTestClass { } ``` """.trimIndent() - val parsingStrategy = JUnitTestSuiteParserStrategy() + val testBodyPrinter = KotlinTestBodyPrinter() val parser = - KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern, parsingStrategy) + KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) - assertTrue(testSuite!!.testCases.isEmpty()) + assertEquals(testSuite!!.packageName, "com.example.testsuite") + assertTrue(testSuite.testCases.isEmpty()) } @Test @@ -171,9 +173,9 @@ class KotlinJUnitTestSuiteParserTest { ``` """.trimIndent() - val parsingStrategy = JUnitTestSuiteParserStrategy() + val testBodyPrinter = KotlinTestBodyPrinter() val parser = - KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern, parsingStrategy) + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) assertEquals(1, testSuite!!.testCases.size) @@ -200,13 +202,59 @@ class KotlinJUnitTestSuiteParserTest { ``` """.trimIndent() - val parsingStrategy = JUnitTestSuiteParserStrategy() + val testBodyPrinter = KotlinTestBodyPrinter() val parser = - KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern, parsingStrategy) + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) assertEquals(2, testSuite!!.testCases.size) assertEquals("firstTestCase", testSuite.testCases[0].name) assertEquals("secondTestCase", testSuite.testCases[1].name) } + + @Test + fun testParseTwoTestCasesWithDifferentPackage() { + val code1 = """ + ```kotlin + package org.pkg1 + + import org.junit.jupiter.api.Test + + class TestCasesClass1 { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val code2 = """ + ```kotlin + package org.pkg2 + + import org.junit.jupiter.api.Test + + class 2TestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + + // packageName will be set to 'org.pkg1' + val testSuite1 = parser.parseTestSuite(code1) + + val testSuite2 = parser.parseTestSuite(code2) + + assertNotNull(testSuite1) + assertNotNull(testSuite2) + assertEquals("org.pkg1", testSuite1!!.packageName) + assertEquals("org.pkg2", testSuite2!!.packageName) + } } diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index f954a1dde..087485827 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -52,6 +52,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { return ClassType.CLASS } + override val rBrace: Int? = psiClass.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val query = ClassInheritorsSearch.search(psiClass, scope, false) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index d800def9a..823c8542b 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -4,23 +4,27 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.editor.Document +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiMethod import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Java + override val language: SupportedLanguage get() = SupportedLanguage.Java private val log = Logger.getInstance(this::class.java) @@ -138,37 +142,48 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + // The cut is always not null for Java, because all functions are always inside the class + val interestingPsiClasses = cut!!.getInterestingPsiClassesWithQualifiedNames(psiMethod) log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List> { + val result: ArrayList> = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper? val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper? val line: Int? = getSurroundingLine(caret.offset) - javaPsiClassWrapped?.let { result.add(getClassHTMLDisplayName(it)) } - javaPsiMethodWrapped?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (javaPsiClassWrapped != null && javaPsiMethodWrapped != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${javaPsiClassWrapped.qualifiedName} \n" + - " 2) Method ${javaPsiMethodWrapped.name} \n" + - " 3) Line $line", - ) - } + javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${javaPsiClassWrapped?.qualifiedName ?: "no class"} \n" + + " 2) Method ${javaPsiMethodWrapped?.name ?: "no method"} \n" + + " 3) Line $line", + ) + + return result + } + + override fun getPackageName(): String { + return (psiFile as PsiJavaFile).packageName + } + + override fun getModuleFromPsiFile(): com.intellij.openapi.module.Module { + return ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + } - return result.toArray() + override fun getDocumentFromPsiFile(): Document { + return psiFile.fileDocument } override fun getLineHTMLDisplayName(line: Int) = "line $line" diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index b1a2d6d92..1b30cc638 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -81,6 +81,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra } } + override val rBrace: Int? = psiClass.body?.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val lightClass = psiClass.toLightClass() diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 9254a9703..785d1f153 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -4,6 +4,8 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.editor.Document +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass @@ -16,19 +18,21 @@ import org.jetbrains.kotlin.idea.base.psi.kotlinFqName import org.jetbrains.kotlin.idea.caches.resolve.analyze import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtFunction import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper -class KotlinPsiHelper(private var psiFile: PsiFile) : PsiHelper { +class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Kotlin + override val language: SupportedLanguage get() = SupportedLanguage.Kotlin private val log = Logger.getInstance(this::class.java) @@ -85,9 +89,10 @@ class KotlinPsiHelper(private var psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, // check if cut has any non-java superclass ) { - val cutPsiClass = getSurroundingClass(caretOffset)!! + val cutPsiClass = getSurroundingClass(caretOffset) ?: return + // will be null for the top level function var currentPsiClass = cutPsiClass for (index in 0 until maxPolymorphismDepth) { if (!classesToTest.contains(currentPsiClass)) { @@ -143,37 +148,49 @@ class KotlinPsiHelper(private var psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + val interestingPsiClasses = + cut?.getInterestingPsiClassesWithQualifiedNames(psiMethod) + ?: (psiMethod as KotlinPsiMethodWrapper).getInterestingPsiClassesWithQualifiedNames() log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List> { + val result: ArrayList> = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) val line: Int? = getSurroundingLine(caret.offset)?.plus(1) - ktClass?.let { result.add(getClassHTMLDisplayName(it)) } - ktFunction?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (ktClass != null && ktFunction != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${ktClass.qualifiedName} \n" + - " 2) Method ${ktFunction.name} \n" + - " 3) Line $line", - ) - } + ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${ktClass?.qualifiedName ?: "no class"} \n" + + " 2) Method ${ktFunction?.name ?: "no method"} \n" + + " 3) Line $line", + ) + + return result + } + + override fun getPackageName(): String { + return (psiFile as KtFile).packageFqName.asString() + } + + override fun getModuleFromPsiFile(): com.intellij.openapi.module.Module { + return ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + } - return result.toArray() + override fun getDocumentFromPsiFile(): Document { + return psiFile.fileDocument } override fun getLineHTMLDisplayName(line: Int) = "line $line" diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index a142aaaa8..ca4b18feb 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -68,6 +68,22 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { return lineNumber in startLine..endLine } + fun getInterestingPsiClassesWithQualifiedNames(): MutableSet { + val interestingPsiClasses = mutableSetOf() + + psiFunction.valueParameters.forEach { parameter -> + val typeReference = parameter.typeReference + if (typeReference != null) { + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) + } + } + } + + return interestingPsiClasses + } + /** * Generates the return descriptor for a method. * diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index eb76f5cc3..c5a33bd94 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -1,11 +1,13 @@ package org.jetbrains.research.testspark.langwrappers import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.editor.Document import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType /** * Interface representing a wrapper for PSI methods, @@ -40,12 +42,14 @@ interface PsiMethodWrapper { * @property name The name of a class * @property qualifiedName The qualified name of the class. * @property text The text of the class. - * @property fullText The source code of the class (with package and imports). - * @property virtualFile - * @property containingFile File where the method is located - * @property superClass The super class of the class * @property methods All methods in the class * @property allMethods All methods in the class and all its superclasses + * @property superClass The super class of the class + * @property virtualFile Virtual file where the class is located + * @property containingFile File where the method is located + * @property fullText The source code of the class (with package and imports). + * @property classType The type of the class + * @property rBrace The offset of the closing brace * */ interface PsiClassWrapper { val name: String @@ -58,6 +62,7 @@ interface PsiClassWrapper { val containingFile: PsiFile val fullText: String val classType: ClassType + val rBrace: Int? /** * Searches for subclasses of the current class within the given project. @@ -81,7 +86,7 @@ interface PsiClassWrapper { * handling the PSI (Program Structure Interface) for different languages. */ interface PsiHelper { - val language: Language + val language: SupportedLanguage /** * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile. @@ -133,7 +138,7 @@ interface PsiHelper { * @return A mutable set of interesting PsiClasses. */ fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet @@ -145,7 +150,7 @@ interface PsiHelper { * The array contains the class display name, method display name (if present), and the line number (if present). * The line number is prefixed with "Line". */ - fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? + fun getCurrentListOfCodeTypes(e: AnActionEvent): List>? /** * Helper for generating method descriptors for methods. @@ -170,6 +175,21 @@ interface PsiHelper { maxPolymorphismDepth: Int, ) + /** + * Get the package name of the file. + */ + fun getPackageName(): String + + /** + * Get the module of the file. + */ + fun getModuleFromPsiFile(): com.intellij.openapi.module.Module + + /** + * Get the module of the file. + */ + fun getDocumentFromPsiFile(): Document? + /** * Gets the display line number. * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry. diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 5a0a96fbc..1f4999661 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -17,6 +17,7 @@ import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelFactory import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.display.TestSparkIcons import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider @@ -115,14 +116,14 @@ class TestSparkAction : AnAction() { private val caretOffset: Int = e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret!!.offset private val fileUrl = e.dataContext.getData(CommonDataKeys.VIRTUAL_FILE)!!.presentableUrl - private val codeTypeButtons: MutableList = mutableListOf() + private val codeTypeButtons: MutableList> = mutableListOf() private val codeTypeButtonGroup = ButtonGroup() private val nextButton = JButton(PluginLabelsBundle.get("next")) private val cardLayout = CardLayout() private val llmSetupPanelFactory = LLMSetupPanelFactory(e, project) - private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project) + private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project, psiHelper.language) private val evoSuitePanelFactory = EvoSuitePanelFactory(project) init { @@ -198,16 +199,16 @@ class TestSparkAction : AnAction() { testGeneratorPanel.add(llmButton) testGeneratorPanel.add(evoSuiteButton) - for (codeType in codeTypes) { - val button = JRadioButton(codeType as String) - codeTypeButtons.add(button) + for ((codeType, codeTypeName) in codeTypes) { + val button = JRadioButton(codeTypeName) + codeTypeButtons.add(codeType to button) codeTypeButtonGroup.add(button) } val codesToTestPanel = JPanel() codesToTestPanel.add(JLabel("Select the code type:")) - if (codeTypeButtons.size == 1) codeTypeButtons[0].isSelected = true - for (button in codeTypeButtons) codesToTestPanel.add(button) + if (codeTypeButtons.size == 1) codeTypeButtons[0].second.isSelected = true + for ((_, button) in codeTypeButtons) codesToTestPanel.add(button) val middlePanel = FormBuilder.createFormBuilder() .setFormLeftIndent(10) @@ -253,7 +254,7 @@ class TestSparkAction : AnAction() { updateNextButton() } - for (button in codeTypeButtons) { + for ((_, button) in codeTypeButtons) { button.addActionListener { llmSetupPanelFactory.setPromptEditorType(button.text) updateNextButton() @@ -330,33 +331,36 @@ class TestSparkAction : AnAction() { if (!testGenerationController.isGeneratorRunning(project)) { val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode() - if (codeTypeButtons[0].isSelected) { - tool.generateTestsForClass( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[1].isSelected) { - tool.generateTestsForMethod( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[2].isSelected) { - tool.generateTestsForLine( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) + for ((codeType, button) in codeTypeButtons) { + if (button.isSelected) { + when (codeType) { + CodeType.CLASS -> tool.generateTestsForClass( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.METHOD -> tool.generateTestsForMethod( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.LINE -> tool.generateTestsForLine( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + } + break + } } } @@ -376,10 +380,7 @@ class TestSparkAction : AnAction() { */ private fun updateNextButton() { val isTestGeneratorButtonGroupSelected = llmButton.isSelected || evoSuiteButton.isSelected - var isCodeTypeButtonGroupSelected = false - for (button in codeTypeButtons) { - isCodeTypeButtonGroupSelected = isCodeTypeButtonGroupSelected || button.isSelected - } + val isCodeTypeButtonGroupSelected = codeTypeButtons.any { it.second.isSelected } nextButton.isEnabled = isTestGeneratorButtonGroupSelected && isCodeTypeButtonGroupSelected if ((llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) || diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index b57ee8d81..b6b77a0ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,6 +4,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup @@ -12,7 +13,7 @@ import javax.swing.JLabel import javax.swing.JPanel import javax.swing.JRadioButton -class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { +class LLMSampleSelectorFactory(private val project: Project, private val language: SupportedLanguage) : PanelFactory { // init components private val selectionTypeButtons: MutableList = mutableListOf( JRadioButton(PluginLabelsBundle.get("provideTestSample")), @@ -128,7 +129,7 @@ class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { } addButton.addActionListener { - val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes) + val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes, language) testSamplePanelFactories.add(testSamplePanelFactory) val testSamplePanel = testSamplePanelFactory.getTestSamplePanel() val codeScrollPanel = testSamplePanelFactory.getCodeScrollPanel() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt index 97cf6d49a..251a45f27 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt @@ -10,6 +10,7 @@ import com.intellij.openapi.ui.ComboBox import com.intellij.ui.LanguageTextField import com.intellij.ui.components.JBScrollPane import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.IconButtonCreator import org.jetbrains.research.testspark.display.ModifiedLinesGetter import org.jetbrains.research.testspark.display.TestCaseDocumentCreator @@ -25,11 +26,12 @@ class TestSamplePanelFactory( private val middlePanel: JPanel, private val testNames: MutableList, private val initialTestCodes: MutableList, + private val language: SupportedLanguage, ) { // init components private val currentTestCodes = initialTestCodes.toMutableList() private val languageTextField = LanguageTextField( - Language.findLanguageByID("JAVA"), + Language.findLanguageByID(language.languageId), project, initialTestCodes[0], TestCaseDocumentCreator("TestSample"), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index b8b0654d3..499abf1c1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -18,7 +18,8 @@ import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.llm.JsonEncoding @@ -26,6 +27,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.Llm @@ -172,6 +174,12 @@ class TestSparkStarter : ApplicationStarter { // Start test generation val indicator = HeadlessProgressIndicator() val errorMonitor = DefaultErrorMonitor() + val testCompiler = TestCompilerFactory.create( + project, + settingsState.junitVersion, + psiHelper.language, + projectSDKPath.toString(), + ) val uiContext = llmProcessManager.runTestGenerator( indicator, FragmentToTestData(CodeType.CLASS), @@ -192,6 +200,7 @@ class TestSparkStarter : ApplicationStarter { classPath, projectContext, projectSDKPath, + testCompiler, ) } else { println("[TestSpark Starter] Test generation failed") @@ -237,6 +246,7 @@ class TestSparkStarter : ApplicationStarter { classPath: String, projectContext: ProjectContext, projectSDKPath: Path, + testCompiler: TestCompiler, ) { val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}" println("Run tests in $targetDirectory") @@ -246,6 +256,7 @@ class TestSparkStarter : ApplicationStarter { var testcaseName = it.nameWithoutExtension.removePrefix("Generated") testcaseName = testcaseName[0].lowercaseChar() + testcaseName.substring(1) // The current test is compiled and is ready to run jacoco + val testExecutionError = TestProcessor(project, projectSDKPath).createXmlFromJacoco( it.nameWithoutExtension, "$targetDirectory${File.separator}jacoco-${it.nameWithoutExtension}", @@ -254,6 +265,7 @@ class TestSparkStarter : ApplicationStarter { packageList.joinToString("."), out, projectContext, + testCompiler, ) // Saving exception (if exists) thrown during the test execution saveException(testcaseName, targetDirectory, testExecutionError) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt deleted file mode 100644 index 8e91aded4..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.data - -/** -* Enum class, which contains all code elements for which it is possible to request test generation. -*/ -enum class CodeType { - CLASS, METHOD, LINE -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt index 0cf79dddb..3c289bb11 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.data +import org.jetbrains.research.testspark.core.test.data.CodeType + /** * Data about test objects that require test generators. */ diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt index c3119cea2..99b0ec5ab 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt @@ -25,17 +25,20 @@ import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.helpers.ReportHelper import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestClassCodeAnalyzerFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.test.JUnitTestSuitePresenter @@ -58,7 +61,7 @@ import javax.swing.border.MatteBorder class TestCasePanelFactory( private val project: Project, - private val language: org.jetbrains.research.testspark.core.test.Language, + private val language: SupportedLanguage, private val testCase: TestCase, editor: Editor, private val checkbox: JCheckBox, @@ -193,7 +196,10 @@ class TestCasePanelFactory( val clipboard: Clipboard = Toolkit.getDefaultToolkit().systemClipboard clipboard.setContents( StringSelection( - project.service().getEditor(testCase.testName)!!.document.text, + when (language) { + SupportedLanguage.Kotlin -> project.service().getEditor(testCase.testName)!!.document.text + SupportedLanguage.Java -> project.service().getEditor(testCase.testName)!!.document.text + }, ), null, ) @@ -386,7 +392,10 @@ class TestCasePanelFactory( } ReportHelper.updateTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service().updateUI() + SupportedLanguage.Java -> project.service().updateUI() + } } /** @@ -454,12 +463,12 @@ class TestCasePanelFactory( } private fun addTest(testSuite: TestSuiteGeneratedByLLM) { - val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput) + val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput, language) WriteCommandAction.runWriteCommandAction(project) { uiContext.errorMonitor.clear() val code = testSuitePresenter.toString(testSuite) - testCase.testName = JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, code) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, code) testCase.testCode = code // update numbers @@ -517,15 +526,24 @@ class TestCasePanelFactory( private fun runTest(indicator: CustomProgressIndicator) { indicator.setText("Executing ${testCase.testName}") + val fileName = TestClassCodeAnalyzerFactory.create(language).getFileNameFromTestCaseCode(testCase.testName) + + val testCompiler = TestCompilerFactory.create( + project, + llmSettingsState.junitVersion, + language, + ) + val newTestCase = TestProcessor(project) .processNewTestCase( - "${JavaClassBuilderHelper.getClassFromTestCaseCode(testCase.testCode)}.java", + fileName, testCase.id, testCase.testName, testCase.testCode, - uiContext!!.testGenerationOutput.packageLine, + uiContext!!.testGenerationOutput.packageName, uiContext.testGenerationOutput.resultPath, uiContext.projectContext, + testCompiler, ) testCase.coveredLines = newTestCase.coveredLines @@ -585,13 +603,23 @@ class TestCasePanelFactory( */ private fun remove() { // Remove the test case from the cache - project.service().removeTestCase(testCase.testName) + when (language) { + SupportedLanguage.Kotlin -> project.service().removeTestCase(testCase.testName) + + SupportedLanguage.Java -> project.service().removeTestCase(testCase.testName) + } runTestButton.isEnabled = false isRemoved = true ReportHelper.removeTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service() + .updateUI() + + SupportedLanguage.Java -> project.service() + .updateUI() + } } /** @@ -663,8 +691,7 @@ class TestCasePanelFactory( * Updates the current test case with the specified test name and test code. */ private fun updateTestCaseInformation() { - testCase.testName = - JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, languageTextField.document.text) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, languageTextField.document.text) testCase.testCode = languageTextField.document.text } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt index 31cc7b9a6..b8f90918c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt @@ -1,6 +1,5 @@ package org.jetbrains.research.testspark.display -import com.intellij.openapi.components.service import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task @@ -8,20 +7,20 @@ import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.display.strategies.TopButtonsPanelStrategy import java.awt.Dimension import java.util.LinkedList import java.util.Queue import javax.swing.Box import javax.swing.BoxLayout import javax.swing.JButton -import javax.swing.JCheckBox import javax.swing.JLabel import javax.swing.JOptionPane import javax.swing.JPanel -class TopButtonsPanelFactory(private val project: Project) { +class TopButtonsPanelFactory(private val project: Project, private val language: SupportedLanguage) { private var runAllButton: JButton = createRunAllTestButton() private var selectAllButton: JButton = IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) @@ -64,28 +63,26 @@ class TopButtonsPanelFactory(private val project: Project) { * Updates the labels. */ fun updateTopLabels() { - var numberOfPassedTests = 0 - for (testCasePanelFactory in testCasePanelFactories) { - if (testCasePanelFactory.isRemoved()) continue - val error = testCasePanelFactory.getError() - if ((error is String) && error.isEmpty()) { - numberOfPassedTests++ - } - } - testsSelectedLabel.text = String.format( - testsSelectedText, - project.service().getTestsSelected(), - project.service().getTestCasePanels().size, - ) - testsPassedLabel.text = - String.format( + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.updateTopJavaLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, testsPassedText, - numberOfPassedTests, - project.service().getTestCasePanels().size, + runAllButton, + ) + + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.updateTopKotlinLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, + testsPassedText, + runAllButton, ) - runAllButton.isEnabled = false - for (testCasePanelFactory in testCasePanelFactories) { - runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() } } @@ -105,31 +102,20 @@ class TopButtonsPanelFactory(private val project: Project) { * @param selected whether the checkboxes have to be selected or not */ private fun toggleAllCheckboxes(selected: Boolean) { - project.service().getTestCasePanels().forEach { (_, jPanel) -> - val checkBox = jPanel.getComponent(0) as JCheckBox - checkBox.isSelected = selected + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.toggleAllJavaCheckboxes(selected, project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.toggleAllKotlinCheckboxes(selected, project) } - project.service() - .setTestsSelected(if (selected) project.service().getTestCasePanels().size else 0) } /** * Removes all test cases from the cache and tool window UI. */ private fun removeAllTestCases() { - // Ask the user for the confirmation - val choice = JOptionPane.showConfirmDialog( - null, - PluginMessagesBundle.get("removeAllMessage"), - PluginMessagesBundle.get("confirmationTitle"), - JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE, - ) - - // Cancel the operation if the user did not press "Yes" - if (choice == JOptionPane.NO_OPTION) return - - project.service().clear() + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.removeAllJavaTestCases(project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.removeAllKotlinTestCases(project) + } } /** diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt index bcad7a834..dee6a2b0e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt @@ -16,7 +16,7 @@ import com.intellij.ui.components.JBLabel import com.intellij.ui.components.JBScrollPane import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.services.EvoSuiteSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState import java.awt.Color import java.awt.Dimension @@ -130,7 +130,7 @@ class CoverageHelper( * @param name name of the test to highlight */ private fun highlightInToolwindow(name: String) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightTestCase(name) } @@ -141,7 +141,7 @@ class CoverageHelper( * @param map map of mutant operations -> List of names of tests which cover the mutants */ private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() }) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt deleted file mode 100644 index 977873bdb..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ /dev/null @@ -1,204 +0,0 @@ -package org.jetbrains.research.testspark.helpers - -import com.github.javaparser.ParseProblemException -import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.visitor.VoidVisitorAdapter -import com.intellij.lang.java.JavaLanguage -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.project.Project -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiFile -import com.intellij.psi.PsiFileFactory -import com.intellij.psi.codeStyle.CodeStyleManager -import org.jetbrains.research.testspark.core.data.TestGenerationData -import java.io.File - -object JavaClassBuilderHelper { - /** - * Generates the code for a test class. - * - * @param className the name of the test class - * @param body the body of the test class - * @return the generated code as a string - */ - fun generateCode( - project: Project, - className: String, - body: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - testGenerationData: TestGenerationData, - ): String { - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) - - // Add each test (exclude expected exception) - testFullText += body - - // close the test class - testFullText += "}" - - testFullText.replace("\r\n", "\n") - - /** - * for better readability and make the tests shorter, we reduce the number of line breaks: - * when we have three or more sequential \n, reduce it to two. - */ - return formatJavaCode(project, Regex("\n\n\n(\n)*").replace(testFullText, "\n\n"), testGenerationData) - } - - /** - * Returns the upper part of test suite (package name, imports, and test class name) as a string. - * - * @return the upper part of test suite (package name, imports, and test class name) as a string. - */ - private fun printUpperPart( - className: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - ): String { - var testText = "" - - // Add package - if (packageString.isNotBlank()) { - testText += "package $packageString;\n" - } - - // add imports - imports.forEach { importedElement -> - testText += "$importedElement\n" - } - - testText += "\n" - - // add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith)\n" - } - // open the test class - testText += "public class $className {\n\n" - - // Add other presets (annotations, non-test functions) - if (otherInfo.isNotBlank()) { - testText += otherInfo - } - - return testText - } - - /** - * Finds the test method from a given class with the specified test case name. - * - * @param code The code of the class containing test methods. - * @return The test method as a string, including the "@Test" annotation. - */ - fun getTestMethodCodeFromClassWithTestCase(code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - val upperCutCode = "\t@Test" + code.split("@Test").last() - var methodStarted = false - var balanceOfBrackets = 0 - for (symbol in upperCutCode) { - result += symbol - if (symbol == '{') { - methodStarted = true - balanceOfBrackets++ - } - if (symbol == '}') { - balanceOfBrackets-- - } - if (methodStarted && balanceOfBrackets == 0) { - break - } - } - return result + "\n" - } - } - - /** - * Retrieves the name of the test method from a given Java class with test cases. - * - * @param oldTestCaseName The old name of test case - * @param code The source code of the Java class with test cases. - * @return The name of the test method. If no test method is found, an empty string is returned. - */ - fun getTestMethodNameFromClassWithTestCase(oldTestCaseName: String, code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result = method.nameAsString - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - return oldTestCaseName - } - } - - /** - * Retrieves the class name from the given test case code. - * - * @param code The test case code to extract the class name from. - * @return The class name extracted from the test case code. - */ - fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - - /** - * Formats the given Java code using IntelliJ IDEA's code formatting rules. - * - * @param code The Java code to be formatted. - * @return The formatted Java code. - */ - fun formatJavaCode(project: Project, code: String, generatedTestData: TestGenerationData): String { - var result = "" - WriteCommandAction.runWriteCommandAction(project) { - val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" - // create a temporary PsiFile - val psiFile: PsiFile = PsiFileFactory.getInstance(project) - .createFileFromText( - fileName, - JavaLanguage.INSTANCE, - code, - ) - - CodeStyleManager.getInstance(project).reformat(psiFile) - - val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) - result = document?.text ?: code - - File(fileName).delete() - } - - return result - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index 6f4ba38a4..d10525087 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -12,12 +12,15 @@ import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModif import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform @@ -244,7 +247,7 @@ object LLMHelper { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun testModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -258,13 +261,28 @@ object LLMHelper { return null } + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + ) + + val testsAssembler = TestsAssemblerFactory.create( + indicator, + testGenerationOutput, + testSuiteParser, + jUnitVersion, + ) + val testSuite = executeTestCaseModificationRequest( language, testCase, task, indicator, requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, testGenerationOutput), + testsAssembler, errorMonitor, ) return testSuite diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt index 5784c01cb..6b257f421 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt @@ -1,425 +1,69 @@ package org.jetbrains.research.testspark.services -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.components.Service -import com.intellij.openapi.components.service -import com.intellij.openapi.fileChooser.FileChooser -import com.intellij.openapi.fileChooser.FileChooserDescriptor -import com.intellij.openapi.fileEditor.FileDocumentManager -import com.intellij.openapi.fileEditor.FileEditorManager -import com.intellij.openapi.fileEditor.OpenFileDescriptor -import com.intellij.openapi.fileEditor.TextEditor -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.LocalFileSystem -import com.intellij.openapi.vfs.VirtualFile -import com.intellij.openapi.vfs.VirtualFileManager -import com.intellij.openapi.wm.ToolWindowManager -import com.intellij.psi.PsiClass -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiElementFactory -import com.intellij.psi.PsiJavaFile -import com.intellij.psi.PsiManager -import com.intellij.refactoring.suggested.startOffset +import com.intellij.psi.PsiFile import com.intellij.ui.EditorTextField -import com.intellij.ui.JBColor -import com.intellij.ui.components.JBScrollPane -import com.intellij.ui.content.Content -import com.intellij.ui.content.ContentFactory -import com.intellij.ui.content.ContentManager -import com.intellij.util.containers.stream -import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle import org.jetbrains.research.testspark.core.data.Report -import org.jetbrains.research.testspark.core.data.TestCase -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.data.UIContext -import org.jetbrains.research.testspark.display.TestCasePanelFactory -import org.jetbrains.research.testspark.display.TopButtonsPanelFactory -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.ReportHelper -import java.awt.BorderLayout -import java.awt.Color -import java.awt.Dimension -import java.io.File -import java.util.Locale -import javax.swing.Box -import javax.swing.BoxLayout -import javax.swing.JButton -import javax.swing.JCheckBox -import javax.swing.JOptionPane +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import javax.swing.JPanel -import javax.swing.JSeparator -import javax.swing.SwingConstants -@Service(Service.Level.PROJECT) -class TestCaseDisplayService(private val project: Project) { - private var report: Report? = null - - private val unselectedTestCases = HashMap() - - private var mainPanel: JPanel = JPanel() - - private val topButtonsPanelFactory = TopButtonsPanelFactory(project) - - private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) - - private var allTestCasePanel: JPanel = JPanel() - - private var scrollPane: JBScrollPane = JBScrollPane( - allTestCasePanel, - JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, - ) - - private var testCasePanels: HashMap = HashMap() - - private var testsSelected: Int = 0 - - /** - * Default color for the editors in the tool window - */ - private var defaultEditorColor: Color? = null - - /** - * Content Manager to be able to add / remove tabs from tool window - */ - private var contentManager: ContentManager? = null - - /** - * Variable to keep reference to the coverage visualisation content - */ - private var content: Content? = null - - var uiContext: UIContext? = null - - init { - allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) - mainPanel.layout = BorderLayout() - - mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) - mainPanel.add(scrollPane, BorderLayout.CENTER) - - applyButton.isOpaque = false - applyButton.isContentAreaFilled = false - mainPanel.add(applyButton, BorderLayout.SOUTH) - - applyButton.addActionListener { applyTests() } - } +interface TestCaseDisplayService { /** * Fill the panel with the generated test cases. Remove all previously shown test cases. * Add Tests and their names to a List of pairs (used for highlighting) */ - fun displayTestCases(report: Report, uiContext: UIContext, language: Language) { - this.report = report - this.uiContext = uiContext - - val editor = project.service().editor!! - - allTestCasePanel.removeAll() - testCasePanels.clear() - - addSeparator() - - // TestCasePanelFactories array - val testCasePanelFactories = arrayListOf() - - report.testCaseList.values.forEach { - val testCase = it - val testCasePanel = JPanel() - testCasePanel.layout = BorderLayout() - - // Add a checkbox to select the test - val checkbox = JCheckBox() - checkbox.isSelected = true - checkbox.addItemListener { - // Update the number of selected tests - testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) - - if (checkbox.isSelected) { - ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) - } else { - ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) - } - - updateUI() - } - testCasePanel.add(checkbox, BorderLayout.WEST) - - val testCasePanelFactory = TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) - testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) - testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) - testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) - - testCasePanelFactories.add(testCasePanelFactory) - - testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) - - // Add panel to parent panel - testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) - allTestCasePanel.add(testCasePanel) - addSeparator() - testCasePanels[testCase.testName] = testCasePanel - } - - // Update the number of selected tests (all tests are selected by default) - testsSelected = testCasePanels.size - - topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) - topButtonsPanelFactory.updateTopLabels() - - createToolWindowTab() - } + fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) /** * Adds a separator to the allTestCasePanel. */ - private fun addSeparator() { - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - } + fun addSeparator() /** * Highlight the mini-editor in the tool window whose name corresponds with the name of the test provided * * @param name name of the test whose editor should be highlighted */ - fun highlightTestCase(name: String) { - val myPanel = testCasePanels[name] ?: return - openToolWindowTab() - scrollToPanel(myPanel) - - val editor = getEditor(name) ?: return - val settingsProjectState = project.service().state - val highlightColor = - JBColor( - PluginSettingsBundle.get("colorName"), - Color( - settingsProjectState.colorRed, - settingsProjectState.colorGreen, - settingsProjectState.colorBlue, - 30, - ), - ) - if (editor.background.equals(highlightColor)) return - defaultEditorColor = editor.background - editor.background = highlightColor - returnOriginalEditorBackground(editor) - } + fun highlightTestCase(name: String) /** * Method to open the toolwindow tab with generated tests if not already open. */ - private fun openToolWindowTab() { - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - toolWindowManager.show() - toolWindowManager.contentManager.setSelectedContent(content!!) - } - } + fun openToolWindowTab() /** * Scrolls to the highlighted panel. * * @param myPanel the panel to scroll to */ - private fun scrollToPanel(myPanel: JPanel) { - var sum = 0 - for (component in allTestCasePanel.components) { - if (component == myPanel) { - break - } else { - sum += component.height - } - } - val scroll = scrollPane.verticalScrollBar - scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height - } + fun scrollToPanel(myPanel: JPanel) /** * Removes all coverage highlighting from the editor. */ - private fun removeAllHighlights() { - project.service().editor?.markupModel?.removeAllHighlighters() - } + fun removeAllHighlights() /** * Reset the provided editors color to the default (initial) one after 10 seconds * @param editor the editor whose color to change */ - private fun returnOriginalEditorBackground(editor: EditorTextField) { - Thread { - Thread.sleep(10000) - editor.background = defaultEditorColor - }.start() - } + fun returnOriginalEditorBackground(editor: EditorTextField) /** * Highlight a range of editors * @param names list of test names to pass to highlight function */ - fun highlightCoveredMutants(names: List) { - names.forEach { - highlightTestCase(it) - } - } + fun highlightCoveredMutants(names: List) /** * Show a dialog where the user can select what test class the tests should be applied to, * and apply the selected tests to the test class. */ - private fun applyTests() { - // Filter the selected test cases - val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } - val selectedTestCases = selectedTestCasePanels.map { it.key } - - // Get the test case components (source code of the tests) - val testCaseComponents = selectedTestCases - .map { getEditor(it)!! } - .map { it.document.text } - - // Descriptor for choosing folders and java files - val descriptor = FileChooserDescriptor(true, true, false, false, false, false) - - // Apply filter with folders and java files with main class - WriteCommandAction.runWriteCommandAction(project) { - descriptor.withFileFilter { file -> - file.isDirectory || ( - file.extension?.lowercase(Locale.getDefault()) == "java" && ( - PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile - ).classes.stream().map { it.name } - .toArray() - .contains( - ( - PsiManager.getInstance(project) - .findFile(file) as PsiJavaFile - ).name.removeSuffix(".java"), - ) - ) - } - } - - val fileChooser = FileChooser.chooseFiles( - descriptor, - project, - LocalFileSystem.getInstance().findFileByPath(project.basePath!!), - ) - - /** - * Cancel button pressed - */ - if (fileChooser.isEmpty()) return - - /** - * Chosen files by user - */ - val chosenFile = fileChooser[0] - - /** - * Virtual file of a final java file - */ - var virtualFile: VirtualFile? = null - - /** - * PsiClass of a final java file - */ - var psiClass: PsiClass? = null - - /** - * PsiJavaFile of a final java file - */ - var psiJavaFile: PsiJavaFile? = null - - if (chosenFile.isDirectory) { - // Input new file data - var className: String - var fileName: String - var filePath: String - // Waiting for correct file name input - while (true) { - val jOptionPane = - JOptionPane.showInputDialog( - null, - PluginLabelsBundle.get("optionPaneMessage"), - PluginLabelsBundle.get("optionPaneTitle"), - JOptionPane.PLAIN_MESSAGE, - null, - null, - null, - ) - - // Cancel button pressed - jOptionPane ?: return - - // Get class name from user - className = jOptionPane as String - - // Set file name and file path - fileName = "${className.split('.')[0]}.java" - filePath = "${chosenFile.path}/$fileName" - - // Check the correctness of a class name - if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { - showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) - continue - } - - // Check the existence of a file with this name - if (File(filePath).exists()) { - showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) - continue - } - break - } - - // Create new file and set services of this file - WriteCommandAction.runWriteCommandAction(project) { - chosenFile.createChildData(null, fileName) - virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + fun applyTests() - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { - psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") - } - - psiJavaFile!!.add(psiClass!!) - } - } else { - // Set services of the chosen file - virtualFile = chosenFile - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = psiJavaFile!!.classes[ - psiJavaFile!!.classes.stream().map { it.name }.toArray() - .indexOf(psiJavaFile!!.name.removeSuffix(".java")), - ] - } - - // Add tests to the file - WriteCommandAction.runWriteCommandAction(project) { - appendTestsToClass(testCaseComponents, psiClass!!, psiJavaFile!!) - } - - // Remove the selected test cases from the cache and the tool window UI - removeSelectedTestCases(selectedTestCasePanels) - - // Open the file after adding - FileEditorManager.getInstance(project).openTextEditor( - OpenFileDescriptor(project, virtualFile!!), - true, - ) - } - - private fun showErrorWindow(message: String) { - JOptionPane.showMessageDialog( - null, - message, - PluginLabelsBundle.get("errorWindowTitle"), - JOptionPane.ERROR_MESSAGE, - ) - } + fun showErrorWindow(message: String) /** * Retrieve the editor corresponding to a particular test case @@ -427,11 +71,7 @@ class TestCaseDisplayService(private val project: Project) { * @param testCaseName the name of the test case * @return the editor corresponding to the test case, or null if it does not exist */ - fun getEditor(testCaseName: String): EditorTextField? { - val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null - val middlePanel = middlePanelComponent as JPanel - return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField - } + fun getEditor(testCaseName: String): EditorTextField? /** * Append the provided test cases to the provided class. @@ -440,107 +80,23 @@ class TestCaseDisplayService(private val project: Project) { * @param selectedClass the class which the test cases should be appended to * @param outputFile the output file for tests */ - private fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClass, outputFile: PsiJavaFile) { - // block document - PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!, - ) - - // insert tests to a code - testCaseComponents.reversed().forEach { - val testMethodCode = - JavaClassBuilderHelper.getTestMethodCodeFromClassWithTestCase( - JavaClassBuilderHelper.formatJavaCode( - project, - it.replace("\r\n", "\n") - .replace("verifyException(", "// verifyException("), - uiContext!!.testGenerationOutput, - ), - ) - // Fix Windows line separators - .replace("\r\n", "\n") - - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - testMethodCode, - ) - } - - // insert other info to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - uiContext!!.testGenerationOutput.otherInfo + "\n", - ) - - // insert imports to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, - uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", - ) - - // insert package to a code - outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! - .insertString( - 0, - if (uiContext!!.testGenerationOutput.packageLine.isEmpty()) { - "" - } else { - "package ${uiContext!!.testGenerationOutput.packageLine};\n\n" - }, - ) - } + fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClassWrapper, outputFile: PsiFile) /** * Utility function that returns the editor for a specific file url, * in case it is opened in the IDE */ - fun updateEditorForFileUrl(fileUrl: String) { - val documentManager = FileDocumentManager.getInstance() - // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 - FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { - val currentFile = documentManager.getFile(it.document) - if (currentFile != null) { - if (currentFile.presentableUrl == fileUrl) { - project.service().editor = it - } - } - } - } + fun updateEditorForFileUrl(fileUrl: String) /** * Creates a new toolWindow tab for the coverage visualisation. */ - private fun createToolWindowTab() { - // Remove generated tests tab from content manager if necessary - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - contentManager!!.removeContent(content!!, true) - } - - // If there is no generated tests tab, make it - val contentFactory: ContentFactory = ContentFactory.getInstance() - content = contentFactory.createContent( - mainPanel, - PluginLabelsBundle.get("generatedTests"), - true, - ) - contentManager!!.addContent(content!!) - - // Focus on generated tests tab and open toolWindow if not opened already - contentManager!!.setSelectedContent(content!!) - toolWindowManager.show() - } + fun createToolWindowTab() /** * Closes the tool window and destroys the content of the tab. */ - private fun closeToolWindow() { - contentManager?.removeContent(content!!, true) - ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() - val coverageVisualisationService = project.service() - coverageVisualisationService.closeToolWindowTab() - } + fun closeToolWindow() /** * Removes the selected tests from the cache, removes all the highlights from the editor and closes the tool window. @@ -549,37 +105,16 @@ class TestCaseDisplayService(private val project: Project) { * * @param selectedTestCasePanels the panels of the selected tests */ - private fun removeSelectedTestCases(selectedTestCasePanels: Map) { - selectedTestCasePanels.forEach { removeTestCase(it.key) } - removeAllHighlights() - closeToolWindow() - } - - fun clear() { - // Remove the tests - val testCasePanelsToRemove = testCasePanels.toMap() - removeSelectedTestCases(testCasePanelsToRemove) + fun removeSelectedTestCases(selectedTestCasePanels: Map) - topButtonsPanelFactory.clear() - } + fun clear() /** * A helper method to remove a test case from the cache and from the UI. * * @param testCaseName the name of the test */ - fun removeTestCase(testCaseName: String) { - // Update the number of selected test cases if necessary - if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { - testsSelected-- - } - - // Remove the test panel from the UI - allTestCasePanel.remove(testCasePanels[testCaseName]) - - // Remove the test panel - testCasePanels.remove(testCaseName) - } + fun removeTestCase(testCaseName: String) /** * Updates the user interface of the tool window. @@ -589,36 +124,26 @@ class TestCaseDisplayService(private val project: Project) { * of the topButtonsPanel object. It also checks if there are no more tests remaining * and closes the tool window if that is the case. */ - fun updateUI() { - // Update the UI of the tool window tab - allTestCasePanel.updateUI() - - topButtonsPanelFactory.updateTopLabels() - - // If no more tests are remaining, close the tool window - if (testCasePanels.size == 0) closeToolWindow() - } + fun updateUI() /** * Retrieves the list of test case panels. * * @return The list of test case panels. */ - fun getTestCasePanels() = testCasePanels + fun getTestCasePanels(): HashMap /** * Retrieves the currently selected tests. * * @return The list of tests currently selected. */ - fun getTestsSelected() = testsSelected + fun getTestsSelected(): Int /** * Sets the number of tests selected. * * @param testsSelected The number of tests selected. */ - fun setTestsSelected(testsSelected: Int) { - this.testsSelected = testsSelected - } + fun setTestsSelected(testsSelected: Int) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt index 0cd1b073a..c4310ba61 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt @@ -2,7 +2,7 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.application.PathManager import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.dependencies.JavaTestCompilationDependencies +import org.jetbrains.research.testspark.core.test.data.dependencies.TestCompilationDependencies import java.io.File /** @@ -16,7 +16,7 @@ class LibraryPathsProvider { private val sep = File.separatorChar private val libPrefix = "${PathManager.getPluginsPath()}${sep}TestSpark${sep}lib$sep" - fun getTestCompilationLibraryPaths() = JavaTestCompilationDependencies.getJarDescriptors().map { descriptor -> + fun getTestCompilationLibraryPaths() = TestCompilationDependencies.getJarDescriptors().map { descriptor -> "$libPrefix${sep}${descriptor.name}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index aa5b694b7..30ed0ba6b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.roots.ProjectRootManager import com.intellij.openapi.util.io.FileUtilRt import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext @@ -22,6 +22,8 @@ import org.jetbrains.research.testspark.services.CoverageVisualisationService import org.jetbrains.research.testspark.services.EditorService import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.util.UUID @@ -29,7 +31,7 @@ import java.util.UUID * Pipeline class represents a pipeline for generating tests in a project. * * @param project the project in which the pipeline is executed. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caretOffset the offset of the caret position in the PSI file. * @param fileUrl the URL of the file being processed, if applicable. * @param packageName the package name of the file being processed. @@ -47,7 +49,7 @@ class Pipeline( init { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! + val cutPsiClass = psiHelper.getSurroundingClass(caretOffset) // get generated test path val testResultDirectory = "${FileUtilRt.getTempDirectory()}${ToolUtils.sep}testSparkResults${ToolUtils.sep}" @@ -57,10 +59,8 @@ class Pipeline( ApplicationManager.getApplication().runWriteAction { projectContext.projectClassPath = ProjectRootManager.getInstance(project).contentRoots.first().path projectContext.fileUrlAsString = fileUrl - projectContext.classFQN = cutPsiClass.qualifiedName - // TODO probably can be made easier - projectContext.cutModule = - ProjectFileIndex.getInstance(project).getModuleForFile(cutPsiClass.virtualFile)!! + cutPsiClass?.let { projectContext.classFQN = it.qualifiedName } + projectContext.cutModule = psiHelper.getModuleFromPsiFile() } generatedTestsData.resultPath = ToolUtils.getResultPath(id, testResultDirectory) @@ -108,14 +108,13 @@ class Pipeline( override fun onFinished() { super.onFinished() testGenerationController.finished() - uiContext?.let { - project.service() - .updateEditorForFileUrl(it.testGenerationOutput.fileUrl) - - if (project.service().editor != null) { - val report = it.testGenerationOutput.testGenerationResultList[0]!! - project.service().displayTestCases(report, it, psiHelper.language) - project.service().showCoverage(report) + when (psiHelper.language) { + SupportedLanguage.Java -> uiContext?.let { + displayTestCase(it) + } + + SupportedLanguage.Kotlin -> uiContext?.let { + displayTestCase(it) } } } @@ -124,8 +123,22 @@ class Pipeline( private fun clear(project: Project) { // should be removed totally! testGenerationController.errorMonitor.clear() - project.service().clear() + when (psiHelper.language) { + SupportedLanguage.Java -> project.service().clear() + SupportedLanguage.Kotlin -> project.service().clear() + } + project.service().clear() project.service().clear() } + + private inline fun displayTestCase(ctx: UIContext) { + project.service().updateEditorForFileUrl(ctx.testGenerationOutput.fileUrl) + + if (project.service().editor != null) { + val report = ctx.testGenerationOutput.testGenerationResultList[0]!! + project.service().displayTestCases(report, ctx, psiHelper.language) + project.service().showCoverage(report) + } + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt index 8680370bd..84b512bb5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt @@ -3,20 +3,31 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.java.JavaTestCompiler +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestCompiler class TestCompilerFactory { companion object { - fun createJavacTestCompiler( + fun create( project: Project, junitVersion: JUnitVersion, + language: SupportedLanguage, javaHomeDirectory: String? = null, ): TestCompiler { - val javaHomePath = javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + val javaSDKHomePath = + javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk?.homeDirectory?.path + ?: throw RuntimeException("Java SDK not configured for the project.") + val libraryPaths = LibraryPathsProvider.getTestCompilationLibraryPaths() val junitLibraryPaths = LibraryPathsProvider.getJUnitLibraryPaths(junitVersion) - return TestCompiler(javaHomePath, libraryPaths, junitLibraryPaths) + // TODO add the warning window that for Java we always need the javaHomeDirectoryPath + return when (language) { + SupportedLanguage.Java -> JavaTestCompiler(libraryPaths, junitLibraryPaths, javaSDKHomePath) + SupportedLanguage.Kotlin -> KotlinTestCompiler(libraryPaths, junitLibraryPaths) + } } } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt index e0a4150b4..d35589357 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt @@ -8,6 +8,7 @@ import com.intellij.openapi.roots.CompilerModuleExtension import com.intellij.openapi.roots.ModuleRootManager import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil @@ -25,16 +26,20 @@ class TestProcessor( val project: Project, givenProjectSDKPath: Path? = null, ) : TestsPersistentStorage { - private val javaHomeDirectory = givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + private val homeDirectory = + givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path private val log = Logger.getInstance(this::class.java) private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state - val testCompiler = TestCompilerFactory.createJavacTestCompiler(project, llmSettingsState.junitVersion, javaHomeDirectory) - - override fun saveGeneratedTest(packageString: String, code: String, resultPath: String, testFileName: String): String { + override fun saveGeneratedTest( + packageString: String, + code: String, + resultPath: String, + testFileName: String, + ): String { // Generate the final path for the generated tests var generatedTestPath = "$resultPath${File.separatorChar}" packageString.split(".").forEach { directory -> @@ -69,14 +74,10 @@ class TestProcessor( generatedTestPackage: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): String { // find the proper javac - val javaRunner = File(javaHomeDirectory).walk() - .filter { - val isJavaName = if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") - isJavaName && it.isFile - } - .first() + val javaRunner = findJavaCompilerInDirectory(homeDirectory) // JaCoCo libs val jacocoAgentLibraryPath = "\"${LibraryPathsProvider.getJacocoAgentLibraryPath()}\"" val jacocoCLILibraryPath = "\"${LibraryPathsProvider.getJacocoCliLibraryPath()}\"" @@ -90,13 +91,21 @@ class TestProcessor( val junitVersion = llmSettingsState.junitVersion.version // run the test method with jacoco agent + log.info("[TestProcessor] Executing $name") val junitRunnerLibraryPath = LibraryPathsProvider.getJUnitRunnerLibraryPath() + // classFQN will be null for the top level function + val javaAgentFlag = + if (projectContext.classFQN != null) { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}" + } else { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false" + } val testExecutionError = CommandLineRunner.run( arrayListOf( javaRunner.absolutePath, - "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}", + javaAgentFlag, "-cp", - "\"${testCompiler.getPath(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", + "\"${testCompiler.getClassPaths(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", "org.jetbrains.research.SingleJUnitTestRunner$junitVersion", name, ), @@ -148,9 +157,10 @@ class TestProcessor( testId: Int, testName: String, testCode: String, - packageLine: String, + packageName: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): TestCase { // get buildPath var buildPath: String = ProjectRootManager.getInstance(project).contentRoots.first().path @@ -161,7 +171,7 @@ class TestProcessor( // save new test to file val generatedTestPath: String = saveGeneratedTest( - packageLine, + packageName, testCode, resultPath, fileName, @@ -179,9 +189,10 @@ class TestProcessor( dataFileName, testName, buildPath, - packageLine, + packageName, resultPath, projectContext, + testCompiler, ) if (!File("$dataFileName.xml").exists()) { @@ -230,7 +241,8 @@ class TestProcessor( frames.removeFirst() frames.forEach { frame -> - if (frame.contains(projectContext.classFQN!!)) { + // classFQN will be null for the top level function + if (projectContext.classFQN != null && frame.contains(projectContext.classFQN!!)) { val coveredLineNumber = frame.split(":")[1].replace(")", "").toIntOrNull() if (coveredLineNumber != null) { result.add(coveredLineNumber) @@ -274,7 +286,8 @@ class TestProcessor( children("counter") {} } children("sourcefile") { - isCorrectSourceFile = this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() + isCorrectSourceFile = + this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() children("line") { if (isCorrectSourceFile && this.attributes.getValue("mi") == "0") { setOfLines.add(this.attributes.getValue("nr").toInt()) @@ -295,4 +308,18 @@ class TestProcessor( return TestCase(testCaseId, testCaseName, testCaseCode, setOfLines) } + + /** + * Finds 'javac' compiler (both on Unix & Windows) + * starting from the provided directory. + */ + private fun findJavaCompilerInDirectory(homeDirectory: String): File { + return File(homeDirectory).walk() + .filter { + val isJavaName = + if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") + isJavaName && it.isFile + } + .first() + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3ba26b9c5..a7ef25eb2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -11,9 +11,9 @@ import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.IJTestCase -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.services.TestsExecutionResultService import java.io.File @@ -21,68 +21,37 @@ object ToolUtils { val sep = File.separatorChar val pathSep = File.pathSeparatorChar - /** - * Retrieves the imports code from a given test suite code. - * - * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. - * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. - * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. - */ - fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String): MutableSet { - testSuiteCode ?: return mutableSetOf() - return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() - .filter { it.contains("^import".toRegex()) } - .filterNot { it.contains("evosuite".toRegex()) } - .filterNot { it.contains("RunWith".toRegex()) } - .filterNot { it.contains(classFQN.toRegex()) }.toMutableSet() - } - - /** - * Retrieves the package declaration from the given test suite code. - * - * @param testSuiteCode The generated code of the test suite. - * @return The package declaration extracted from the test suite code, or an empty string if no package declaration was found. - */ -// get package from a generated code - fun getPackageFromTestSuiteCode(testSuiteCode: String?): String { - testSuiteCode ?: return "" - if (!testSuiteCode.contains("package")) return "" - val result = testSuiteCode.replace("\r\n", "\n").split("\n") - .filter { it.contains("^package".toRegex()) }.joinToString("").split("package ")[1].split(";")[0] - if (result.isBlank()) return "" - return result - } - /** * Saves the data related to test generation in the specified project's workspace. * * @param project The project in which the test generation data will be saved. * @param report The report object to be added to the test generation result list. - * @param packageLine The package declaration line of the test generation data. + * @param packageName The package declaration line of the test generation data. * @param importsCode The import statements code of the test generation data. */ fun saveData( project: Project, report: Report, - packageLine: String, + packageName: String, importsCode: MutableSet, fileUrl: String, generatedTestData: TestGenerationData, + language: SupportedLanguage = SupportedLanguage.Java, ) { generatedTestData.fileUrl = fileUrl - generatedTestData.packageLine = packageLine + generatedTestData.packageName = packageName generatedTestData.importsCode.addAll(importsCode) project.service().initExecutionResult(report.testCaseList.values.map { it.id }) for (testCase in report.testCaseList.values) { val code = testCase.testCode - testCase.testCode = JavaClassBuilderHelper.generateCode( + testCase.testCode = TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCase.testName), code, generatedTestData.importsCode, - generatedTestData.packageLine, + generatedTestData.packageName, generatedTestData.runWith, generatedTestData.otherInfo, generatedTestData, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 46b982ac1..4e4c75a75 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -5,7 +5,7 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.actions.controllers.TestGenerationController -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt index c1e5e6560..8c180f9df 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt @@ -15,10 +15,13 @@ import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteDefaultsBundle import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.core.utils.CommandLineRunner -import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext @@ -200,8 +203,8 @@ class EvoSuiteProcessManager( ToolUtils.saveData( project, IJReport(testGenerationResult), - ToolUtils.getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode), - ToolUtils.getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), + getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode, SupportedLanguage.Java), + getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), projectContext.fileUrlAsString!!, generatedTestsData, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 01f16176c..0379bb183 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -1,11 +1,12 @@ package org.jetbrains.research.testspark.tools.llm import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -23,6 +24,8 @@ import java.nio.file.Path */ class Llm(override val name: String = "LLM") : Tool { + private val log = Logger.getInstance(this::class.java) + /** * Returns an instance of the LLMProcessManager. * @@ -74,6 +77,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests fo CLASS was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -107,6 +111,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests fo METHOD was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -141,6 +146,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests fo LINE was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -174,9 +180,10 @@ class Llm(override val name: String = "LLM") : Tool { fileUrl: String?, testGenerationController: TestGenerationController, ): Pipeline { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! - val packageList = cutPsiClass.qualifiedName.split(".").dropLast(1) - val packageName = packageList.joinToString(".") +// val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! +// val packageList = cutPsiClass.qualifiedName.split(".").dropLast(1) +// val packageName = packageList.joinToString(".") + val packageName = psiHelper.getPackageName() return Pipeline(project, psiHelper, caretOffset, fileUrl, packageName, testGenerationController) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index 2cc74298c..1196016b2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -1,37 +1,27 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language import org.jetbrains.research.testspark.core.test.TestSuiteParser import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.java.JavaJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy -import org.jetbrains.research.testspark.core.utils.javaImportPattern -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState /** * Assembler class for generating and organizing test cases. * - * @property project The project to which the tests belong. * @property indicator The progress indicator to display the progress of test generation. * @property log The logger for logging debug information. * @property lastTestCount The count of the last generated tests. */ class JUnitTestsAssembler( - val project: Project, val indicator: CustomProgressIndicator, - val generationData: TestGenerationData, + private val generationData: TestGenerationData, + private val testSuiteParser: TestSuiteParser, + val junitVersion: JUnitVersion, ) : TestsAssembler() { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state private val log: Logger = Logger.getInstance(this.javaClass) @@ -59,11 +49,8 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? { - val junitVersion = llmSettingsState.junitVersion - - val parser = createTestSuiteParser(packageName, junitVersion, language) - val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(super.getContent()) + override fun assembleTestSuite(): TestSuiteGeneratedByLLM? { + val testSuite = testSuiteParser.parseTestSuite(super.getContent()) // save RunWith if (testSuite?.runWith?.isNotBlank() == true) { @@ -81,16 +68,4 @@ class JUnitTestsAssembler( testSuite?.testCases?.forEach { testCase -> log.info("Generated test case: $testCase") } return testSuite } - - private fun createTestSuiteParser( - packageName: String, - jUnitVersion: JUnitVersion, - language: Language, - ): TestSuiteParser { - val parsingStrategy = JUnitTestSuiteParserStrategy() - return when (language) { - Language.Java -> JavaJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern, parsingStrategy) - Language.Kotlin -> KotlinJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern, parsingStrategy) - } - } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index 89f07e0f7..1a394d04f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -3,15 +3,19 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.FeedbackCycleExecutionResult import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.data.FragmentToTestData @@ -20,8 +24,11 @@ import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager @@ -42,19 +49,23 @@ import java.nio.file.Path */ class LLMProcessManager( private val project: Project, - private val language: Language, + private val language: SupportedLanguage, private val promptManager: PromptManager, private val testSamplesCode: String, - projectSDKPath: Path? = null, + private val projectSDKPath: Path? = null, ) : ProcessManager { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state - private val testFileName: String = "GeneratedTest.java" + private val homeDirectory = + projectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + + private val testFileName: String = when (language) { + SupportedLanguage.Java -> "GeneratedTest.java" + SupportedLanguage.Kotlin -> "GeneratedTest.kt" + } private val log = Logger.getInstance(this::class.java) private val llmErrorManager: LLMErrorManager = LLMErrorManager() private val maxRequests = LlmSettingsArguments(project).maxLLMRequest() - private val testProcessor = TestProcessor(project, projectSDKPath) + private val testProcessor: TestsPersistentStorage = TestProcessor(project, projectSDKPath) /** * Runs the test generator process. @@ -91,16 +102,16 @@ class LLMProcessManager( val report = IJReport() // PROMPT GENERATION - val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) - - val testCompiler = testProcessor.testCompiler + val initialPromptMessage = + promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) // initiate a new RequestManager val requestManager = StandardRequestManagerFactory(project).getRequestManager(project) // adapter for the existing prompt reduction functionality val promptSizeReductionStrategy = object : PromptSizeReductionStrategy { - override fun isReductionPossible(): Boolean = promptManager.isPromptSizeReductionPossible(generatedTestsData) + override fun isReductionPossible(): Boolean = + promptManager.isPromptSizeReductionPossible(generatedTestsData) override fun reduceSizeAndGeneratePrompt(): String { if (!isReductionPossible()) { @@ -115,7 +126,7 @@ class LLMProcessManager( // adapter for the existing test case/test suite string representing functionality val testsPresenter = object : TestsPresenter { - private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) override fun representTestSuite(testSuite: TestSuiteGeneratedByLLM): String { return testSuitePresenter.toStringWithoutExpectedException(testSuite) @@ -126,6 +137,29 @@ class LLMProcessManager( } } + // Creation of JUnit specific parser, printer and assembler + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + packageName, + ) + val testsAssembler = TestsAssemblerFactory.create( + indicator, + generatedTestsData, + testSuiteParser, + jUnitVersion, + ) + + val testCompiler = TestCompilerFactory.create( + project, + jUnitVersion, + language, + homeDirectory, + ) + // Asking LLM to generate a test suite. Here we have a feedback cycle for LLM in case of wrong responses val llmFeedbackCycle = LLMWithFeedbackCycle( language = language, @@ -137,7 +171,7 @@ class LLMProcessManager( resultPath = generatedTestsData.resultPath, buildPath = buildPath, requestManager = requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, generatedTestsData), + testsAssembler = testsAssembler, testCompiler = testCompiler, testStorage = testProcessor, testsPresenter = testsPresenter, @@ -150,8 +184,10 @@ class LLMProcessManager( when (warning) { LLMWithFeedbackCycle.WarningType.TEST_SUITE_PARSING_FAILED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.NO_TEST_CASES_GENERATED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.COMPILATION_ERROR_OCCURRED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("compilationError"), project) } @@ -167,17 +203,21 @@ class LLMProcessManager( // store compilable test cases generatedTestsData.compilableTestCases.addAll(feedbackResponse.compilableTestCases) } + FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("invalidLLMResult"), project, errorMonitor) } + FeedbackCycleExecutionResult.CANCELED -> { log.info("Process stopped") return null } + FeedbackCycleExecutionResult.PROVIDED_PROMPT_TOO_LONG -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("tooLongPromptRequest"), project, errorMonitor) return null } + FeedbackCycleExecutionResult.SAVING_TEST_FILES_ISSUE -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("savingTestFileIssue"), project, errorMonitor) } @@ -190,7 +230,7 @@ class LLMProcessManager( log.info("Save generated test suite and test cases into the project workspace") - val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) val generatedTestSuite: TestSuiteGeneratedByLLM? = feedbackResponse.generatedTestSuite val testSuiteRepresentation = if (generatedTestSuite != null) testSuitePresenter.toString(generatedTestSuite) else null @@ -200,10 +240,11 @@ class LLMProcessManager( ToolUtils.saveData( project, report, - ToolUtils.getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation), - ToolUtils.getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN!!), + getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation, language), + getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN), projectContext.fileUrlAsString!!, generatedTestsData, + language, ) return UIContext(projectContext, generatedTestsData, requestManager, errorMonitor) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index d7ac8f9f5..0282f93c8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -5,7 +5,6 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.util.Computable import com.intellij.openapi.util.TextRange -import com.intellij.psi.PsiDocumentManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle import org.jetbrains.research.testspark.core.data.TestGenerationData @@ -15,7 +14,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptConfiguration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptGenerationContext import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptTemplates -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -39,6 +38,7 @@ class PromptManager( private val psiHelper: PsiHelper, private val caret: Int, ) { + // The classesToTest is empty when we work with the function outside the class private val classesToTest: List get() { val classesToTest = mutableListOf() @@ -52,7 +52,8 @@ class PromptManager( return classesToTest } - private val cut: PsiClassWrapper = classesToTest[0] + // The cut is null when we work with the function outside the class + private val cut: PsiClassWrapper? = if (classesToTest.isNotEmpty()) classesToTest[0] else null private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state @@ -79,7 +80,7 @@ class PromptManager( .toMap() val context = PromptGenerationContext( - cut = createClassRepresentation(cut), + cut = cut?.let { createClassRepresentation(it) }, classesToTest = classesToTest.map(this::createClassRepresentation).toList(), polymorphismRelations = polymorphismRelations, promptConfiguration = PromptConfiguration( @@ -110,7 +111,7 @@ class PromptManager( .map(this::createClassRepresentation) .toList() - promptGenerator.generatePromptForMethod(method, interestingClassesFromMethod, testSamplesCode) + promptGenerator.generatePromptForMethod(method, interestingClassesFromMethod, testSamplesCode, psiHelper.getPackageName()) } CodeType.LINE -> { @@ -118,7 +119,7 @@ class PromptManager( val psiMethod = getPsiMethod(cut, getMethodDescriptor(cut, lineNumber))!! // get code of line under test - val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) + val document = psiHelper.getDocumentFromPsiFile() val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) val lineEndOffset = document.getLineEndOffset(lineNumber - 1) @@ -149,7 +150,7 @@ class PromptManager( signature = psiMethod.signature, name = psiMethod.name, text = psiMethod.text!!, - containingClassQualifiedName = psiMethod.containingClass!!.qualifiedName, + containingClassQualifiedName = psiMethod.containingClass?.qualifiedName ?: "", ) } @@ -210,7 +211,6 @@ class PromptManager( * * @param project The project context in which the PsiClasses exist. * @param interestingPsiClasses The set of PsiClassWrappers that are considered interesting. - * @param cutPsiClass The cut PsiClassWrapper to determine polymorphism relations against. * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. */ private fun getPolymorphismRelationsWithQualifiedNames( @@ -219,6 +219,9 @@ class PromptManager( ): MutableMap> { val polymorphismRelations: MutableMap> = mutableMapOf() + // assert(interestingPsiClasses.isEmpty()) + if (cut == null) return polymorphismRelations + interestingPsiClasses.add(cut) interestingPsiClasses.forEach { currentInterestingClass -> @@ -245,9 +248,15 @@ class PromptManager( * @return The matching PsiMethod if found, otherwise an empty string. */ private fun getPsiMethod( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, methodDescriptor: String, ): PsiMethodWrapper? { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + if (psiHelper.generateMethodDescriptor(currentPsiMethod) == methodDescriptor) return currentPsiMethod + return null + } for (currentPsiMethod in psiClass.allMethods) { val file = psiClass.containingFile val psiHelper = PsiHelperProvider.getPsiHelper(file) @@ -268,9 +277,14 @@ class PromptManager( * @return the method descriptor as a String, or an empty string if no method is found */ private fun getMethodDescriptor( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, lineNumber: Int, ): String { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return psiHelper.generateMethodDescriptor(currentPsiMethod) + } for (currentPsiMethod in psiClass.allMethods) { if (currentPsiMethod.containsLine(lineNumber)) { val file = psiClass.containingFile diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index b1473b0c9..10aded741 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -3,12 +3,14 @@ package org.jetbrains.research.testspark.tools.llm.test import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper +import org.jetbrains.research.testspark.tools.TestClassCodeGeneratorFactory class JUnitTestSuitePresenter( private val project: Project, private val generatedTestsData: TestGenerationData, + private val language: SupportedLanguage, ) { /** * Returns a string representation of this object. @@ -34,12 +36,12 @@ class JUnitTestSuitePresenter( // Add each test testCases.forEach { testCase -> testBody += "$testCase\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -57,12 +59,12 @@ class JUnitTestSuitePresenter( testCaseIndex: Int, ): String = testSuite.run { - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCases[testCaseIndex].name), testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -81,12 +83,12 @@ class JUnitTestSuitePresenter( // Add each test (exclude expected exception) testCases.forEach { testCase -> testBody += "${testCase.toStringWithoutExpectedException()}\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -105,8 +107,8 @@ class JUnitTestSuitePresenter( fun getPrintablePackageString(testSuite: TestSuiteGeneratedByLLM): String { return testSuite.run { when { - packageString.isEmpty() || packageString.isBlank() -> "" - else -> packageString + packageName.isEmpty() || packageName.isBlank() -> "" + else -> packageName } } } From d7611841abea1f5e0c2fadf2ba25391a8213c55c Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 20:33:50 +0200 Subject: [PATCH 05/19] TestClassBuilderHelper refactoring --- .../kotlin/KotlinPsiMethodWrapper.kt | 16 --- .../testspark/appstarter/TestSparkStarter.kt | 2 +- .../research/testspark/helpers/LLMHelper.kt | 4 +- .../helpers/TestClassBuilderHelper.kt | 57 ---------- .../helpers/TestClassCodeAnalyzer.kt | 39 +++++++ .../helpers/TestClassCodeGenerator.kt | 43 ++++++++ .../helpers/java/JavaTestClassCodeAnalyzer.kt | 78 +++++++++++++ ...elper.kt => JavaTestClassCodeGenerator.kt} | 80 +------------- .../kotlin/KotlinTestClassCodeAnalyzer.kt | 65 +++++++++++ ...per.kt => KotlinTestClassCodeGenerator.kt} | 62 +---------- .../java/JavaTestCaseDisplayService.kt | 7 +- .../testspark/tools/TestBodyPrinterFactory.kt | 2 +- .../tools/TestClassCodeAnalyzerFactory.kt | 21 ++++ .../tools/TestClassCodeGeneratorFactory.kt | 21 ++++ .../testspark/tools/TestsAssemblerFactory.kt | 2 +- .../research/testspark/tools/ToolUtils.kt | 35 ++---- .../tools/llm/generation/LLMProcessManager.kt | 8 +- .../tools/llm/test/JUnitTestSuitePresenter.kt | 103 ++++++------------ 18 files changed, 330 insertions(+), 315 deletions(-) delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassBuilderHelper.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt rename src/main/kotlin/org/jetbrains/research/testspark/helpers/java/{JavaClassBuilderHelper.kt => JavaTestClassCodeGenerator.kt} (50%) create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt rename src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/{KotlinClassBuilderHelper.kt => KotlinTestClassCodeGenerator.kt} (58%) create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index 3571339d1..c993fd808 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -68,22 +68,6 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { return lineNumber in startLine..endLine } - fun getInterestingPsiClassesWithQualifiedNames(): MutableSet { - val interestingPsiClasses = mutableSetOf() - - psiFunction.valueParameters.forEach { parameter -> - val typeReference = parameter.typeReference - if (typeReference != null) { - val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) - if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { - interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) - } - } - } - - return interestingPsiClasses - } - /** * Returns a set of `PsiClassWrapper` instances for non-standard Kotlin classes referenced by the * parameters of the current function. diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index 5274c0269..499abf1c1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -174,7 +174,7 @@ class TestSparkStarter : ApplicationStarter { // Start test generation val indicator = HeadlessProgressIndicator() val errorMonitor = DefaultErrorMonitor() - val testCompiler = TestCompilerFactory.createTestCompiler( + val testCompiler = TestCompilerFactory.create( project, settingsState.junitVersion, psiHelper.language, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index 3471fdbcf..d10525087 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -262,14 +262,14 @@ object LLMHelper { } val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion - val testBodyPrinter = TestBodyPrinterFactory.createTestBodyPrinter(language) + val testBodyPrinter = TestBodyPrinterFactory.create(language) val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( jUnitVersion, language, testBodyPrinter, ) - val testsAssembler = TestsAssemblerFactory.createTestsAssembler( + val testsAssembler = TestsAssemblerFactory.create( indicator, testGenerationOutput, testSuiteParser, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassBuilderHelper.kt deleted file mode 100644 index 265fca954..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassBuilderHelper.kt +++ /dev/null @@ -1,57 +0,0 @@ -package org.jetbrains.research.testspark.helpers - -import com.intellij.openapi.project.Project -import org.jetbrains.research.testspark.core.data.TestGenerationData - -interface TestClassBuilderHelper { - /** - * Generates the code for a test class. - * - * @param className the name of the test class - * @param body the body of the test class - * @return the generated code as a string - */ - fun generateCode( - project: Project, - className: String, - body: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - testGenerationData: TestGenerationData, - ): String - - /** - * Extracts the code of the first test method found in the given class code. - * - * @param classCode The code of the class containing test methods. - * @return The code of the first test method as a string, including the "@Test" annotation. - */ - fun extractFirstTestMethodCode(classCode: String): String - - /** - * Retrieves the name of the first test method found in the given class code. - * - * @param oldTestCaseName The old name of test case - * @param classCode The source code of the class containing test methods. - * @return The name of the first test method. If no test method is found, an empty string is returned. - */ - fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String - - /** - * Retrieves the class name from the given test case code. - * - * @param code The test case code to extract the class name from. - * @return The class name extracted from the test case code. - */ - fun getClassFromTestCaseCode(code: String): String - - /** - * Formats the given Java code using IntelliJ IDEA's code formatting rules. - * - * @param code The Java code to be formatted. - * @return The formatted Java code. - */ - fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b20891ed4 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.helpers + +/** + * Interface for retrieving information from test class code. + */ +interface TestClassCodeAnalyzer { + /** + * Extracts the code of the first test method found in the given class code. + * + * @param classCode The code of the class containing test methods. + * @return The code of the first test method as a string, including the "@Test" annotation. + */ + fun extractFirstTestMethodCode(classCode: String): String + + /** + * Retrieves the name of the first test method found in the given class code. + * + * @param oldTestCaseName The old name of a test case + * @param classCode The source code of the class containing test methods. + * @return The name of the first test method. If no test method is found, an empty string is returned. + */ + fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String + + /** + * Retrieves the class name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getClassFromTestCaseCode(code: String): String + + /** + * Return the right file name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getFileNameFromTestCaseCode(code: String): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt new file mode 100644 index 000000000..7443b1664 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt @@ -0,0 +1,43 @@ +package org.jetbrains.research.testspark.helpers + +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.core.data.TestGenerationData + +/** + * Interface for generating and formatting test class code. + */ +interface TestClassCodeGenerator { + /** + * Generates the code for a test class. + * + * @param project the current project + * @param className the name of the test class + * @param body the body of the test class + * @param imports the set of imports needed in the test class + * @param packageString the package declaration of the test class + * @param runWith the runWith annotation for the test class + * @param otherInfo any other additional information for the test class + * @param testGenerationData the data used for test generation + * @return the generated code as a string + */ + fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String + + /** + * Formats the given Java code using IntelliJ IDEA's code formatting rules. + * + * @param project the current project + * @param code the Java code to be formatted + * @param generatedTestData the data used for generating the test + * @return the formatted Java code + */ + fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..f6f2fd0a9 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt @@ -0,0 +1,78 @@ +package org.jetbrains.research.testspark.helpers.java + +import com.github.javaparser.ParseProblemException +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.visitor.VoidVisitorAdapter +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object JavaTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + val upperCutCode = "\t@Test" + classCode.split("@Test").last() + var methodStarted = false + var balanceOfBrackets = 0 + for (symbol in upperCutCode) { + result += symbol + if (symbol == '{') { + methodStarted = true + balanceOfBrackets++ + } + if (symbol == '}') { + balanceOfBrackets-- + } + if (methodStarted && balanceOfBrackets == 0) { + break + } + } + return result + "\n" + } + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result = method.nameAsString + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + return oldTestCaseName + } + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.java" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt similarity index 50% rename from src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaClassBuilderHelper.kt rename to src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt index e86340a0d..46c071d5f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaClassBuilderHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt @@ -1,22 +1,20 @@ package org.jetbrains.research.testspark.helpers.java -import com.github.javaparser.ParseProblemException -import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.visitor.VoidVisitorAdapter import com.intellij.lang.java.JavaLanguage import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiFile import com.intellij.psi.PsiFileFactory import com.intellij.psi.codeStyle.CodeStyleManager import org.jetbrains.research.testspark.core.data.TestGenerationData -import org.jetbrains.research.testspark.helpers.TestClassBuilderHelper +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator import java.io.File -object JavaClassBuilderHelper : TestClassBuilderHelper { +object JavaTestClassCodeGenerator : TestClassCodeGenerator { + + private val log = Logger.getInstance(this::class.java) override fun generateCode( project: Project, @@ -45,69 +43,6 @@ object JavaClassBuilderHelper : TestClassBuilderHelper { return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) } - override fun extractFirstTestMethodCode(classCode: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - val upperCutCode = "\t@Test" + classCode.split("@Test").last() - var methodStarted = false - var balanceOfBrackets = 0 - for (symbol in upperCutCode) { - result += symbol - if (symbol == '{') { - methodStarted = true - balanceOfBrackets++ - } - if (symbol == '}') { - balanceOfBrackets-- - } - if (methodStarted && balanceOfBrackets == 0) { - break - } - } - return result + "\n" - } - } - - override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) - - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result = method.nameAsString - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - return oldTestCaseName - } - } - - override fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { var result = "" WriteCommandAction.runWriteCommandAction(project) { @@ -131,11 +66,6 @@ object JavaClassBuilderHelper : TestClassBuilderHelper { return result } - /** - * Returns the upper part of test suite (package name, imports, and test class name) as a string. - * - * @return the upper part of test suite (package name, imports, and test class name) as a string. - */ private fun printUpperPart( className: String, imports: Set, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b21a97dfd --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt @@ -0,0 +1,65 @@ +package org.jetbrains.research.testspark.helpers.kotlin + +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object KotlinTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + val testMethods = StringBuilder() + val lines = classCode.lines() + + var methodStarted = false + var balanceOfBrackets = 0 + + for (line in lines) { + if (!methodStarted && line.contains("@Test")) { + methodStarted = true + testMethods.append(line).append("\n") + } else if (methodStarted) { + testMethods.append(line).append("\n") + for (char in line) { + if (char == '{') { + balanceOfBrackets++ + } else if (char == '}') { + balanceOfBrackets-- + } + } + if (balanceOfBrackets == 0) { + methodStarted = false + testMethods.append("\n") + } + } + } + + return testMethods.toString() + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + val lines = classCode.lines() + var testMethodName = oldTestCaseName + + for (line in lines) { + if (line.contains("@Test")) { + val methodDeclarationLine = lines[lines.indexOf(line) + 1] + val matchResult = Regex("fun\\s+(\\w+)\\s*\\(").find(methodDeclarationLine) + if (matchResult != null) { + testMethodName = matchResult.groupValues[1] + } + break + } + } + return testMethodName + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.kt" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt similarity index 58% rename from src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinClassBuilderHelper.kt rename to src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt index 93b0e7be1..eb10a7aa9 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinClassBuilderHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt @@ -9,10 +9,10 @@ import com.intellij.psi.PsiFileFactory import com.intellij.psi.codeStyle.CodeStyleManager import org.jetbrains.kotlin.idea.KotlinLanguage import org.jetbrains.research.testspark.core.data.TestGenerationData -import org.jetbrains.research.testspark.helpers.TestClassBuilderHelper +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator import java.io.File -object KotlinClassBuilderHelper : TestClassBuilderHelper { +object KotlinTestClassCodeGenerator : TestClassCodeGenerator { private val log = Logger.getInstance(this::class.java) @@ -28,7 +28,8 @@ object KotlinClassBuilderHelper : TestClassBuilderHelper { ): String { log.debug("[KotlinClassBuilderHelper] Generate code for $className") - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) + var testFullText = + printUpperPart(className, imports, packageString, runWith, otherInfo) // Add each test (exclude expected exception) testFullText += body @@ -42,61 +43,6 @@ object KotlinClassBuilderHelper : TestClassBuilderHelper { return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) } - override fun extractFirstTestMethodCode(classCode: String): String { - val testMethods = StringBuilder() - val lines = classCode.lines() - - var methodStarted = false - var balanceOfBrackets = 0 - - for (line in lines) { - if (!methodStarted && line.contains("@Test")) { - methodStarted = true - testMethods.append(line).append("\n") - } else if (methodStarted) { - testMethods.append(line).append("\n") - for (char in line) { - if (char == '{') { - balanceOfBrackets++ - } else if (char == '}') { - balanceOfBrackets-- - } - } - if (balanceOfBrackets == 0) { - methodStarted = false - testMethods.append("\n") - } - } - } - - return testMethods.toString() - } - - override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { - val lines = classCode.lines() - var testMethodName = oldTestCaseName - - for (line in lines) { - if (line.contains("@Test")) { - val methodDeclarationLine = lines[lines.indexOf(line) + 1] - val matchResult = Regex("fun\\s+(\\w+)\\s*\\(").find(methodDeclarationLine) - if (matchResult != null) { - testMethodName = matchResult.groupValues[1] - } - break - } - } - return testMethodName - } - - override fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { var result = "" WriteCommandAction.runWriteCommandAction(project) { diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt index 5a910a13a..0dbc5009c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt @@ -37,7 +37,8 @@ import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.display.TestCasePanelFactory import org.jetbrains.research.testspark.display.TopButtonsPanelFactory import org.jetbrains.research.testspark.helpers.ReportHelper -import org.jetbrains.research.testspark.helpers.java.JavaClassBuilderHelper +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator import org.jetbrains.research.testspark.java.JavaPsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.services.CoverageVisualisationService @@ -413,8 +414,8 @@ class JavaTestCaseDisplayService(private val project: Project) : TestCaseDisplay // insert tests to a code testCaseComponents.reversed().forEach { val testMethodCode = - JavaClassBuilderHelper.extractFirstTestMethodCode( - JavaClassBuilderHelper.formatCode( + JavaTestClassCodeAnalyzer.extractFirstTestMethodCode( + JavaTestClassCodeGenerator.formatCode( project, it.replace("\r\n", "\n") .replace("verifyException(", "// verifyException("), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt index d47bfaa75..ea0c0bc2e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt @@ -7,7 +7,7 @@ import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter class TestBodyPrinterFactory { companion object { - fun createTestBodyPrinter(language: SupportedLanguage): TestBodyPrinter { + fun create(language: SupportedLanguage): TestBodyPrinter { return when (language) { SupportedLanguage.Kotlin -> KotlinTestBodyPrinter() SupportedLanguage.Java -> JavaTestBodyPrinter() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt new file mode 100644 index 000000000..1b73c380c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer + +object TestClassCodeAnalyzerFactory { + /** + * Creates an instance of TestClassCodeAnalyzer for the specified language. + * + * @param language the programming language for which to create the analyzer + * @return an instance of TestClassCodeAnalyzer + */ + fun create(language: SupportedLanguage): TestClassCodeAnalyzer { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeAnalyzer + SupportedLanguage.Java -> JavaTestClassCodeAnalyzer + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt new file mode 100644 index 000000000..56151e26e --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator + +object TestClassCodeGeneratorFactory { + /** + * Creates an instance of TestClassCodeGenerator for the specified language. + * + * @param language the programming language for which to create the generator + * @return an instance of TestClassCodeGenerator + */ + fun create(language: SupportedLanguage): TestClassCodeGenerator { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeGenerator + SupportedLanguage.Java -> JavaTestClassCodeGenerator + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt index d096af21f..a896d273c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt @@ -8,7 +8,7 @@ import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler class TestsAssemblerFactory { companion object { - fun createTestsAssembler( + fun create( indicator: CustomProgressIndicator, generationData: TestGenerationData, testSuiteParser: TestSuiteParser, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3e8eb70a0..a7ef25eb2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -14,8 +14,6 @@ import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.IJTestCase -import org.jetbrains.research.testspark.helpers.java.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.kotlin.KotlinClassBuilderHelper import org.jetbrains.research.testspark.services.TestsExecutionResultService import java.io.File @@ -48,29 +46,16 @@ object ToolUtils { for (testCase in report.testCaseList.values) { val code = testCase.testCode - testCase.testCode = when (language) { - SupportedLanguage.Java -> JavaClassBuilderHelper.generateCode( - project, - getClassWithTestCaseName(testCase.testName), - code, - generatedTestData.importsCode, - generatedTestData.packageName, - generatedTestData.runWith, - generatedTestData.otherInfo, - generatedTestData, - ) - - SupportedLanguage.Kotlin -> KotlinClassBuilderHelper.generateCode( - project, - getClassWithTestCaseName(testCase.testName), - code, - generatedTestData.importsCode, - generatedTestData.packageName, - generatedTestData.runWith, - generatedTestData.otherInfo, - generatedTestData, - ) - } + testCase.testCode = TestClassCodeGeneratorFactory.create(language).generateCode( + project, + getClassWithTestCaseName(testCase.testName), + code, + generatedTestData.importsCode, + generatedTestData.packageName, + generatedTestData.runWith, + generatedTestData.otherInfo, + generatedTestData, + ) } generatedTestData.testGenerationResultList.add(report) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index edaffa60d..f46dd5603 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -138,21 +138,21 @@ class LLMProcessManager( // Creation of JUnit specific parser, printer and assembler val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion - val testBodyPrinter = TestBodyPrinterFactory.createTestBodyPrinter(language) + val testBodyPrinter = TestBodyPrinterFactory.create(language) val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( jUnitVersion, language, testBodyPrinter, - packageName + packageName, ) - val testsAssembler = TestsAssemblerFactory.createTestsAssembler( + val testsAssembler = TestsAssemblerFactory.create( indicator, generatedTestsData, testSuiteParser, jUnitVersion, ) - val testCompiler = TestCompilerFactory.createTestCompiler( + val testCompiler = TestCompilerFactory.create( project, jUnitVersion, language, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index 91e201db6..10aded741 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -5,8 +5,7 @@ import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.helpers.java.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.kotlin.KotlinClassBuilderHelper +import org.jetbrains.research.testspark.tools.TestClassCodeGeneratorFactory class JUnitTestSuitePresenter( private val project: Project, @@ -37,29 +36,16 @@ class JUnitTestSuitePresenter( // Add each test testCases.forEach { testCase -> testBody += "$testCase\n" } - when (language) { - SupportedLanguage.Java -> JavaClassBuilderHelper.generateCode( - project, - testFileName, - testBody, - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - - SupportedLanguage.Kotlin -> KotlinClassBuilderHelper.generateCode( - project, - testFileName, - testBody, - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - } + TestClassCodeGeneratorFactory.create(language).generateCode( + project, + testFileName, + testBody, + imports, + packageName, + runWith, + otherInfo, + generatedTestsData, + ) } } @@ -73,29 +59,16 @@ class JUnitTestSuitePresenter( testCaseIndex: Int, ): String = testSuite.run { - when (language) { - SupportedLanguage.Java -> JavaClassBuilderHelper.generateCode( - project, - getClassWithTestCaseName(testCases[testCaseIndex].name), - testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - - SupportedLanguage.Kotlin -> KotlinClassBuilderHelper.generateCode( - project, - getClassWithTestCaseName(testCases[testCaseIndex].name), - testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - } + TestClassCodeGeneratorFactory.create(language).generateCode( + project, + getClassWithTestCaseName(testCases[testCaseIndex].name), + testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", + imports, + packageName, + runWith, + otherInfo, + generatedTestsData, + ) } /** @@ -110,30 +83,16 @@ class JUnitTestSuitePresenter( // Add each test (exclude expected exception) testCases.forEach { testCase -> testBody += "${testCase.toStringWithoutExpectedException()}\n" } - when (language) { - SupportedLanguage.Java -> - JavaClassBuilderHelper.generateCode( - project, - testFileName, - testBody, - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - - SupportedLanguage.Kotlin -> KotlinClassBuilderHelper.generateCode( - project, - testFileName, - testBody, - imports, - packageName, - runWith, - otherInfo, - generatedTestsData, - ) - } + TestClassCodeGeneratorFactory.create(language).generateCode( + project, + testFileName, + testBody, + imports, + packageName, + runWith, + otherInfo, + generatedTestsData, + ) } } From e6afcb6e1d3ffca09553e407249c8acc14533bfd Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 20:41:25 +0200 Subject: [PATCH 06/19] renaming the getSurroundingLineNumber function --- build.gradle.kts | 30 +- .../testspark/core/data/TestGenerationData.kt | 4 +- .../generation/llm/LLMWithFeedbackCycle.kt | 14 +- .../testspark/core/generation/llm/Utils.kt | 48 +- .../generation/llm/network/RequestManager.kt | 12 +- .../generation/llm/prompt/PromptBuilder.kt | 13 +- .../generation/llm/prompt/PromptGenerator.kt | 6 +- .../llm/prompt/configuration/Configuration.kt | 5 +- .../testspark/core/test/SupportedLanguage.kt | 8 + .../testspark/core/test/TestBodyPrinter.kt | 21 + .../testspark/core/test/TestCompiler.kt | 60 +- .../test/{parsers => }/TestSuiteParser.kt | 4 +- .../testspark/core/test/TestsAssembler.kt | 8 +- .../core/test/TestsPersistentStorage.kt | 1 + .../testspark/core/test}/data/CodeType.kt | 2 +- .../core/test/data/TestCaseGeneratedByLLM.kt | 29 +- .../core/test/data/TestSuiteGeneratedByLLM.kt | 4 +- ...cies.kt => TestCompilationDependencies.kt} | 2 +- .../test/java/JavaJUnitTestSuiteParser.kt | 32 + .../core/test/java/JavaTestBodyPrinter.kt | 40 ++ .../core/test/java/JavaTestCompiler.kt | 53 ++ .../test/kotlin/KotlinJUnitTestSuiteParser.kt | 32 + .../core/test/kotlin/KotlinTestBodyPrinter.kt | 40 ++ .../core/test/kotlin/KotlinTestCompiler.kt | 31 + .../parsers/java/JavaJUnitTestSuiteParser.kt | 22 - .../kotlin/KotlinJUnitTestSuiteParser.kt | 22 - .../JUnitTestSuiteParserStrategy.kt | 173 ------ .../JUnitTestSuiteParserStrategy.kt | 175 ++++++ .../research/testspark/core/utils/Language.kt | 8 - .../research/testspark/core/utils/Patterns.kt | 10 +- .../kotlin/KotlinJUnitTestSuiteParserTest.kt | 161 ++++- .../testspark/java/JavaPsiClassWrapper.kt | 32 +- .../research/testspark/java/JavaPsiHelper.kt | 51 +- .../resources/META-INF/testspark-java.xml | 9 - .../testspark/kotlin/KotlinPsiClassWrapper.kt | 40 +- .../testspark/kotlin/KotlinPsiHelper.kt | 59 +- .../kotlin/KotlinPsiMethodWrapper.kt | 20 + .../resources/META-INF/testspark-kotlin.xml | 8 - langwrappers/build.gradle.kts | 2 - .../LanguageClassTextExtractor.kt | 7 + .../testspark/langwrappers/PsiComponents.kt | 40 +- .../JavaKotlinClassTextExtractor.kt | 39 ++ .../testspark/actions/TestSparkAction.kt | 87 +-- .../actions/llm/LLMSampleSelectorFactory.kt | 7 +- .../actions/llm/LLMSetupPanelFactory.kt | 6 +- .../actions/llm/TestSamplePanelFactory.kt | 4 +- .../testspark/appstarter/TestSparkStarter.kt | 14 +- .../testspark/data/FragmentToTestData.kt | 2 + .../testspark/display/TestCasePanelFactory.kt | 53 +- .../display/TopButtonsPanelFactory.kt | 70 +-- .../strategies/TopButtonsPanelStrategy.kt | 138 +++++ .../testspark/helpers/CoverageHelper.kt | 6 +- .../helpers/JavaClassBuilderHelper.kt | 204 ------- .../research/testspark/helpers/LLMHelper.kt | 55 +- .../helpers/TestClassCodeAnalyzer.kt | 39 ++ .../helpers/TestClassCodeGenerator.kt | 43 ++ .../helpers/java/JavaTestClassCodeAnalyzer.kt | 78 +++ .../java/JavaTestClassCodeGenerator.kt | 104 ++++ .../kotlin/KotlinTestClassCodeAnalyzer.kt | 65 ++ .../kotlin/KotlinTestClassCodeGenerator.kt | 101 ++++ .../CoverageToolWindowDisplayService.kt | 0 .../services/TestCaseDisplayService.kt | 527 +---------------- .../java/JavaTestCaseDisplayService.kt | 544 +++++++++++++++++ .../kotlin/KotlinTestCaseDisplayService.kt | 553 ++++++++++++++++++ .../settings/llm/LLMSettingsComponent.kt | 2 +- .../settings/llm/LLMSettingsConfigurable.kt | 12 + .../settings/llm/LLMSettingsState.kt | 6 + .../testspark/tools/LibraryPathsProvider.kt | 4 +- .../research/testspark/tools/Pipeline.kt | 45 +- .../testspark/tools/TestBodyPrinterFactory.kt | 17 + .../tools/TestClassCodeAnalyzerFactory.kt | 21 + .../tools/TestClassCodeGeneratorFactory.kt | 21 + .../testspark/tools/TestCompilerFactory.kt | 17 +- .../research/testspark/tools/TestProcessor.kt | 61 +- .../testspark/tools/TestSuiteParserFactory.kt | 31 + .../testspark/tools/TestsAssemblerFactory.kt | 18 + .../research/testspark/tools/ToolUtils.kt | 45 +- .../testspark/tools/evosuite/EvoSuite.kt | 2 +- .../evosuite/EvoSuiteSettingsArguments.kt | 2 +- .../generation/EvoSuiteProcessManager.kt | 9 +- .../research/testspark/tools/llm/Llm.kt | 12 +- .../tools/llm/LlmSettingsArguments.kt | 4 +- .../llm/generation/JUnitTestsAssembler.kt | 35 +- .../tools/llm/generation/LLMProcessManager.kt | 76 ++- .../tools/llm/generation/PromptManager.kt | 44 +- .../llm/generation/RequestManagerFactory.kt | 2 + .../generation/grazie/GrazieRequestManager.kt | 6 +- .../llm/generation/hf/HuggingFacePlatform.kt | 9 + .../generation/hf/HuggingFaceRequestBody.kt | 33 ++ .../hf/HuggingFaceRequestManager.kt | 116 ++++ .../generation/openai/OpenAIRequestBody.kt | 25 +- .../generation/openai/OpenAIRequestManager.kt | 29 +- .../tools/llm/test/JUnitTestSuitePresenter.kt | 20 +- src/main/resources/META-INF/plugin.xml | 2 +- .../properties/llm/LLMDefaults.properties | 4 + .../properties/llm/LLMMessages.properties | 3 +- .../SettingsArgumentsLlmEvoSuiteTest.kt | 2 +- 97 files changed, 3309 insertions(+), 1503 deletions(-) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt rename core/src/main/kotlin/org/jetbrains/research/testspark/core/test/{parsers => }/TestSuiteParser.kt (87%) rename {src/main/kotlin/org/jetbrains/research/testspark => core/src/main/kotlin/org/jetbrains/research/testspark/core/test}/data/CodeType.kt (73%) rename core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/{JavaTestCompilationDependencies.kt => TestCompilationDependencies.kt} (96%) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt delete mode 100644 java/src/main/resources/META-INF/testspark-java.xml delete mode 100644 kotlin/src/main/resources/META-INF/testspark-kotlin.xml create mode 100644 langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt create mode 100644 langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt diff --git a/build.gradle.kts b/build.gradle.kts index b83f7e6bd..5e6621e29 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,6 +1,5 @@ import org.jetbrains.changelog.markdownToHTML import org.jetbrains.intellij.tasks.RunIdeTask -import org.jetbrains.intellij.tasks.RunPluginVerifierTask import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import java.io.FileOutputStream import java.net.URL @@ -158,6 +157,7 @@ dependencies { // https://mvnrepository.com/artifact/org.mockito/mockito-all testImplementation("org.mockito:mockito-all:1.10.19") + testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0") // https://mvnrepository.com/artifact/net.jqwik/jqwik testImplementation("net.jqwik:jqwik:1.6.5") @@ -209,15 +209,6 @@ tasks { dependsOn(":core:compileKotlin") } - verifyPlugin { - dependsOn(":copyPluginAssets") - onlyIf { this.project == rootProject } - } - - runIde { - onlyIf { this.project == rootProject } - } - // Set the JVM compatibility versions properties("javaVersion").let { withType { @@ -296,25 +287,6 @@ tasks { // https://plugins.jetbrains.com/docs/intellij/deployment.html#specifying-a-release-channel channels.set(listOf(properties("pluginVersion").split('-').getOrElse(1) { "default" }.split('.').first())) } - - withType { - onlyIf { this.project == rootProject } - mustRunAfter("check") - - // 1.365 is broken, -// remove this version as soon as https://youtrack.jetbrains.com/issue/MP-6438 is fixed. -// verifierVersion.set("1.364") - ideVersions.set(properties("ideVersionVerifier").split(",")) - failureLevel.set( - listOf( - RunPluginVerifierTask.FailureLevel.INTERNAL_API_USAGES, - RunPluginVerifierTask.FailureLevel.COMPATIBILITY_PROBLEMS, - RunPluginVerifierTask.FailureLevel.OVERRIDE_ONLY_API_USAGES, - RunPluginVerifierTask.FailureLevel.NON_EXTENDABLE_API_USAGES, - RunPluginVerifierTask.FailureLevel.PLUGIN_STRUCTURE_WARNINGS, - ) - ) - } } abstract class CopyJUnitRunnerLib : DefaultTask() { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt index d11f346d5..a35212cb1 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt @@ -16,7 +16,7 @@ data class TestGenerationData( // Code required of imports and package for generated tests var importsCode: MutableSet = mutableSetOf(), - var packageLine: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", @@ -37,7 +37,7 @@ data class TestGenerationData( resultName = "" fileUrl = "" importsCode = mutableSetOf() - packageLine = "" + packageName = "" runWith = "" otherInfo = "" polyDepthReducing = 0 diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 0c8a428aa..973b26e7a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -10,13 +10,13 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeRed import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import java.io.File enum class FeedbackCycleExecutionResult { @@ -45,7 +45,7 @@ data class FeedbackResponse( class LLMWithFeedbackCycle( private val report: Report, - private val language: Language, + private val language: SupportedLanguage, private val initialPromptMessage: String, private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in result path @@ -167,13 +167,15 @@ class LLMWithFeedbackCycle( generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) } else { for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = - "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + val testCaseFilename = when (language) { + SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt" + } val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) val saveFilepath = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testCaseRepresentation, resultPath, testCaseFilename, @@ -184,7 +186,7 @@ class LLMWithFeedbackCycle( } val generatedTestSuitePath: String = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testsPresenter.representTestSuite(generatedTestSuite), resultPath, testSuiteFilename, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index 76cb74c17..1942a6a86 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -4,13 +4,47 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.utils.javaPackagePattern +import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import java.util.Locale // TODO: find a better place for the below functions +/** + * Retrieves the package declaration from the given test suite code for any language. + * + * @param testSuiteCode The generated code of the test suite. + * @return The package name extracted from the test suite code, or an empty string if no package declaration was found. + */ +fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String { + testSuiteCode ?: return "" + return when (language) { + SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + } +} + +/** + * Retrieves the imports code from a given test suite code. + * + * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. + * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. + * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. + */ +fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String?): MutableSet { + testSuiteCode ?: return mutableSetOf() + return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() + .filter { it.contains("^import".toRegex()) } + .filterNot { it.contains("evosuite".toRegex()) } + .filterNot { it.contains("RunWith".toRegex()) } + // classFQN will be null for the top level function + .filterNot { classFQN != null && it.contains(classFQN.toRegex()) } + .toMutableSet() +} + /** * Returns the generated class name for a given test case. * @@ -39,7 +73,7 @@ fun getClassWithTestCaseName(testCaseName: String): String { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun executeTestCaseModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -50,15 +84,7 @@ fun executeTestCaseModificationRequest( // Update Token information val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task" - var packageName = "" - testCase.split("\n")[0].let { - if (it.startsWith("package")) { - packageName = it - .removePrefix("package ") - .removeSuffix(";") - .trim() - } - } + val packageName = getPackageFromTestSuiteCode(testCase, language) val response = requestManager.request( language, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index 689eec798..441e51231 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -7,8 +7,8 @@ import org.jetbrains.research.testspark.core.data.ChatUserMessage import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.core.utils.Language abstract class RequestManager(var token: String) { enum class SendResult { @@ -31,7 +31,7 @@ abstract class RequestManager(var token: String) { * @return the generated TestSuite, or null and prompt message */ open fun request( - language: Language, + language: SupportedLanguage, prompt: String, indicator: CustomProgressIndicator, packageName: String, @@ -65,7 +65,7 @@ abstract class RequestManager(var token: String) { open fun processResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { // save the full response in the chat history val response = testsAssembler.getContent() @@ -78,7 +78,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) @@ -97,7 +97,7 @@ abstract class RequestManager(var token: String) { open fun processUserFeedbackResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { val response = testsAssembler.getContent() @@ -108,7 +108,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 278d58655..036e87a0d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -78,7 +78,7 @@ internal class PromptBuilder(private var prompt: String) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" } for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java")) { + if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { continue } @@ -88,7 +88,9 @@ internal class PromptBuilder(private var prompt: String) { // Skip java methods // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java")) { + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin") + ) { continue } @@ -106,8 +108,11 @@ internal class PromptBuilder(private var prompt: String) { ) = apply { val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { - var fullText = "" - + // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable + var fullText = when { + polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" + else -> "" + } polymorphismRelations.forEach { entry -> for (currentSubClass in entry.value) { val subClassTypeName = when (currentSubClass.classType) { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt index 3afbd3cff..72340867a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt @@ -19,7 +19,7 @@ class PromptGenerator( fun generatePromptForClass(interestingClasses: List, testSamplesCode: String): String { val prompt = PromptBuilder(promptTemplates.classPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(context.cut.qualifiedName) + .insertName(context.cut!!.qualifiedName) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(context.cut.fullText, context.classesToTest) @@ -44,10 +44,12 @@ class PromptGenerator( method: MethodRepresentation, interestingClassesFromMethod: List, testSamplesCode: String, + packageName: String, ): String { + val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" val prompt = PromptBuilder(promptTemplates.methodPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName("${context.cut.qualifiedName}.${method.name}") + .insertName(name) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(method.text, context.classesToTest) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt index 4094de1aa..6b87e8941 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt @@ -10,7 +10,10 @@ import org.jetbrains.research.testspark.core.data.ClassType * @property polymorphismRelations A map where the key represents a ClassRepresentation object and the value is a list of its detected subclasses. */ data class PromptGenerationContext( - val cut: ClassRepresentation, + /** + * The cut is null when we want to generate tests for top-level function + */ + val cut: ClassRepresentation?, val classesToTest: List, val polymorphismRelations: Map>, val promptConfiguration: PromptConfiguration, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt new file mode 100644 index 000000000..4b4de90c8 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt @@ -0,0 +1,8 @@ +package org.jetbrains.research.testspark.core.test + +/** + * Language ID string should be the same as the language name in com.intellij.lang.Language + */ +enum class SupportedLanguage(val languageId: String) { + Java("JAVA"), Kotlin("kotlin") +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt new file mode 100644 index 000000000..450400ac3 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.core.test + +import org.jetbrains.research.testspark.core.test.data.TestLine + +interface TestBodyPrinter { + /** + * Generates a test body as a string based on the provided parameters. + * + * @param testInitiatedText A string containing the upper part of the test case. + * @param lines A mutable list of `TestLine` objects representing the lines of the test body. + * @param throwsException The exception type that the test function throws, if any. + * @param name The name of the test function. + * @return A string representing the complete test body. + */ + fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt index bc4d40617..b49281aaf 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt @@ -1,32 +1,24 @@ package org.jetbrains.research.testspark.core.test -import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil -import java.io.File data class TestCasesCompilationResult( val allTestCasesCompilable: Boolean, val compilableTestCases: MutableSet, ) -/** - * TestCompiler is a class that is responsible for compiling generated test cases using the proper javac. - * It provides methods for compiling test cases and code files. - */ -open class TestCompiler( - private val javaHomeDirectoryPath: String, +abstract class TestCompiler( private val libPaths: List, private val junitLibPaths: List, ) { - private val log = KotlinLogging.logger { this::class.java } - /** - * Compiles the generated files with test cases using the proper javac. + * Compiles a list of test cases and returns the compilation result. * - * @return true if all the provided test cases are successfully compiled, - * otherwise returns false. + * @param generatedTestCasesPaths A list of file paths where the generated test cases are located. + * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. + * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled. + * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases. */ fun compileTestCases( generatedTestCasesPaths: List, @@ -51,45 +43,11 @@ open class TestCompiler( * Compiles the code at the specified path using the provided project build path. * * @param path The path of the code file to compile. - * @param projectBuildPath The project build path to use during compilation. + * @param projectBuildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. * @return A pair containing a boolean value indicating whether the compilation was successful (true) or not (false), * and a string message describing any error encountered during compilation. */ - fun compileCode(path: String, projectBuildPath: String): Pair { - // find the proper javac - val javaCompile = File(javaHomeDirectoryPath).walk() - .filter { - val isCompilerName = if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") - isCompilerName && it.isFile - } - .firstOrNull() - - if (javaCompile == null) { - val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" - log.error { msg } - throw RuntimeException(msg) - } - - println("javac found at '${javaCompile.absolutePath}'") - - // compile file - val errorMsg = CommandLineRunner.run( - arrayListOf( - javaCompile.absolutePath, - "-cp", - "\"${getPath(projectBuildPath)}\"", - path, - ), - ) - - log.info { "Error message: '$errorMsg'" } - - // create .class file path - val classFilePath = path.replace(".java", ".class") - - // check is .class file exists - return Pair(File(classFilePath).exists(), errorMsg) - } + abstract fun compileCode(path: String, projectBuildPath: String): Pair /** * Generates the path for the command by concatenating the necessary paths. @@ -97,7 +55,7 @@ open class TestCompiler( * @param buildPath The path of the build file. * @return The generated path as a string. */ - fun getPath(buildPath: String): String { + fun getClassPaths(buildPath: String): String { // create the path for the command val separator = DataFilesUtil.classpathSeparator val dependencyLibPath = libPaths.joinToString(separator.toString()) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt similarity index 87% rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt index a0551ed7c..60c4016d4 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt @@ -1,4 +1,4 @@ -package org.jetbrains.research.testspark.core.test.parsers +package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM @@ -11,7 +11,7 @@ data class TestCaseParseResult( interface TestSuiteParser { /** - * Extracts test cases from raw text and generates a test suite using the given package name. + * Extracts test cases from raw text and generates a test suite. * * @param rawText The raw text provided by the LLM that contains the generated test cases. * @return A GeneratedTestSuite instance containing the extracted test cases. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index 6e5a4e127..0d9c672de 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -1,7 +1,6 @@ package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language abstract class TestsAssembler { private var rawText = "" @@ -33,10 +32,9 @@ abstract class TestsAssembler { } /** - * Extracts test cases from raw text and generates a TestSuite using the given package name. + * Extracts test cases from raw text and generates a TestSuite. * - * @param packageName The package name to be set in the generated TestSuite. - * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name. + * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases. */ - abstract fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM? } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt index 1673fea4a..b9d50132c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt @@ -4,6 +4,7 @@ package org.jetbrains.research.testspark.core.test * The TestPersistentStorage interface represents a contract for saving generated tests to a specified file system location. */ interface TestsPersistentStorage { + /** * Save the generated tests to a specified directory. * diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt similarity index 73% rename from src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt index 8e91aded4..12f18eb54 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt @@ -1,4 +1,4 @@ -package org.jetbrains.research.testspark.data +package org.jetbrains.research.testspark.core.test.data /** * Enum class, which contains all code elements for which it is possible to request test generation. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt index 6ef9f6907..2a565e82e 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.core.test.data +import org.jetbrains.research.testspark.core.test.TestBodyPrinter + /** * * Represents a test case generated by LLM. @@ -11,6 +13,7 @@ data class TestCaseGeneratedByLLM( var expectedException: String = "", var throwsException: String = "", var lines: MutableList = mutableListOf(), + val printTestBodyStrategy: TestBodyPrinter, ) { /** @@ -104,31 +107,7 @@ data class TestCaseGeneratedByLLM( * @return a string containing the body of test case */ private fun printTestBody(testInitiatedText: String): String { - var testFullText = testInitiatedText - - // start writing the test signature - testFullText += "\n\tpublic void $name() " - - // add throws exception if exists - if (throwsException.isNotBlank()) { - testFullText += "throws $throwsException" - } - - // start writing the test lines - testFullText += "{\n" - - // write each line - lines.forEach { line -> - testFullText += when (line.type) { - TestLineType.BREAK -> "\t\t\n" - else -> "\t\t${line.text}\n" - } - } - - // close test case - testFullText += "\t}\n" - - return testFullText + return printTestBodyStrategy.printTestBody(testInitiatedText, lines, throwsException, name) } /** diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt index 211063bb7..4fac9b8b9 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt @@ -4,12 +4,12 @@ package org.jetbrains.research.testspark.core.test.data * Represents a test suite generated by LLM. * * @property imports The set of import statements in the test suite. - * @property packageString The package string of the test suite. + * @property packageName The package name of the test suite. * @property testCases The list of test cases in the test suite. */ data class TestSuiteGeneratedByLLM( var imports: Set = emptySet(), - var packageString: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", var testCases: MutableList = mutableListOf(), diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt similarity index 96% rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt index 2e78b0b50..622ab0c98 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt @@ -6,7 +6,7 @@ import org.jetbrains.research.testspark.core.data.JarLibraryDescriptor * The class represents a list of dependencies required for java test compilation. * The libraries listed are used during test suite/test case compilation. */ -class JavaTestCompilationDependencies { +class TestCompilationDependencies { companion object { fun getJarDescriptors() = listOf( JarLibraryDescriptor( diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt new file mode 100644 index 000000000..279badc57 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt @@ -0,0 +1,32 @@ +package org.jetbrains.research.testspark.core.test.java + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.javaImportPattern + +class JavaJUnitTestSuiteParser( + private var packageName: String, + private val junitVersion: JUnitVersion, + private val testBodyPrinter: TestBodyPrinter, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Java) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( + rawText, + junitVersion, + javaImportPattern, + packageName, + testNamePattern = "void", + testBodyPrinter, + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt new file mode 100644 index 000000000..bafbcaf13 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt @@ -0,0 +1,40 @@ +package org.jetbrains.research.testspark.core.test.java + +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType + +class JavaTestBodyPrinter : TestBodyPrinter { + override fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String { + var testFullText = testInitiatedText + + // start writing the test signature + testFullText += "\n\tpublic void $name() " + + // add throws exception if exists + if (throwsException.isNotBlank()) { + testFullText += "throws $throwsException" + } + + // start writing the test lines + testFullText += "{\n" + + // write each line + lines.forEach { line -> + testFullText += when (line.type) { + TestLineType.BREAK -> "\t\t\n" + else -> "\t\t${line.text}\n" + } + } + + // close test case + testFullText += "\t}\n" + + return testFullText + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt new file mode 100644 index 000000000..98f0a3d0c --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt @@ -0,0 +1,53 @@ +package org.jetbrains.research.testspark.core.test.java + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.utils.CommandLineRunner +import org.jetbrains.research.testspark.core.utils.DataFilesUtil +import java.io.File + +class JavaTestCompiler( + libPaths: List, + junitLibPaths: List, + private val javaHomeDirectoryPath: String, +) : TestCompiler(libPaths, junitLibPaths) { + + private val log = KotlinLogging.logger { this::class.java } + + override fun compileCode(path: String, projectBuildPath: String): Pair { + val classPaths = "\"${getClassPaths(projectBuildPath)}\"" + // find the proper javac + val javaCompile = File(javaHomeDirectoryPath).walk() + .filter { + val isCompilerName = + if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") + isCompilerName && it.isFile + } + .firstOrNull() + + if (javaCompile == null) { + val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" + log.error { msg } + throw RuntimeException(msg) + } + + println("javac found at '${javaCompile.absolutePath}'") + + // compile file + val errorMsg = CommandLineRunner.run( + arrayListOf( + javaCompile.absolutePath, + "-cp", + classPaths, + path, + ), + ) + + log.info { "Error message: '$errorMsg'" } + // create .class file path + val classFilePath = path.replace(".java", ".class") + + // check is .class file exists + return Pair(File(classFilePath).exists(), errorMsg) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt new file mode 100644 index 000000000..18b164810 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt @@ -0,0 +1,32 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.kotlinImportPattern + +class KotlinJUnitTestSuiteParser( + private var packageName: String, + private val junitVersion: JUnitVersion, + private val testBodyPrinter: TestBodyPrinter, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Kotlin) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( + rawText, + junitVersion, + kotlinImportPattern, + packageName, + testNamePattern = "fun", + testBodyPrinter, + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt new file mode 100644 index 000000000..a1a9dc8df --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt @@ -0,0 +1,40 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType + +class KotlinTestBodyPrinter : TestBodyPrinter { + override fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String { + var testFullText = testInitiatedText + + // start writing the test signature + testFullText += "\n\tfun $name() " + + // add throws exception if exists + if (throwsException.isNotBlank()) { + testFullText += "throws $throwsException" + } + + // start writing the test lines + testFullText += "{\n" + + // write each line + lines.forEach { line -> + testFullText += when (line.type) { + TestLineType.BREAK -> "\t\t\n" + else -> "\t\t${line.text}\n" + } + } + + // close test case + testFullText += "\t}\n" + + return testFullText + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt new file mode 100644 index 000000000..8d61ce68e --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt @@ -0,0 +1,31 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.utils.CommandLineRunner + +class KotlinTestCompiler(libPaths: List, junitLibPaths: List) : + TestCompiler(libPaths, junitLibPaths) { + + private val log = KotlinLogging.logger { this::class.java } + + override fun compileCode(path: String, projectBuildPath: String): Pair { + log.info { "[KotlinTestCompiler] Compiling ${path.substringAfterLast('/')}" } + + val classPaths = "\"${getClassPaths(projectBuildPath)}\"" + // Compile file + val errorMsg = CommandLineRunner.run( + arrayListOf( + "kotlinc", + "-cp", + classPaths, + path, + ), + ) + + log.info { "Error message: '$errorMsg'" } + + // No need to save the .class file for kotlin, so checking the error message is enough + return Pair(errorMsg.isBlank(), errorMsg) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt deleted file mode 100644 index a8728bbf2..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.java - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class JavaJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "void", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt deleted file mode 100644 index 09bdbc627..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.kotlin - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class KotlinJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "fun", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt deleted file mode 100644 index 98c6827c5..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt +++ /dev/null @@ -1,173 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.strategies - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.test.data.TestLine -import org.jetbrains.research.testspark.core.test.data.TestLineType -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestCaseParseResult - -class JUnitTestSuiteParserStrategy { - companion object { - fun parseTestSuite( - rawText: String, - junitVersion: JUnitVersion, - importPattern: Regex, - packageName: String, - testNamePattern: String, - ): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - var rawCode = rawText - - if (rawText.contains("```")) { - rawCode = rawText.split("```")[1] - } - - // save imports - val imports = importPattern.findAll(rawCode, 0) - .map { it.groupValues[0] } - .toSet() - - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" - - val testSet: MutableList = rawCode.split("@Test").toMutableList() - - // save annotations and pre-set methods - val otherInfo: String = run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } - - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() - - testSet.forEach ca@{ - val rawTest = "@Test$it" - - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: TestCaseParseResult = - testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern) // /// - - if (result.errorOccurred) { - println("WARNING: ${result.errorMessage}") - return@ca - } - - val currentTest = result.testCase!! - - // TODO: make logging work - // log.info("New test case: $currentTest") - println("New test case: $currentTest") - - testCases.add(currentTest) - } - - val testSuite = TestSuiteGeneratedByLLM( - imports = imports, - packageString = packageName, - runWith = runWith, - otherInfo = otherInfo, - testCases = testCases, - ) - - return testSuite - } catch (e: Exception) { - return null - } - } - } -} - -private class JUnitTestCaseParser { - fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean, testNamePattern: String): TestCaseParseResult { - var expectedException = "" - var throwsException = "" - val testLines: MutableList = mutableListOf() - - // Get expected Exception - if (rawTest.startsWith("@Test(expected =")) { - expectedException = rawTest.split(")")[0].trim() - } - - // Get unexpected exceptions - /* Each test case should follow fun {...} - Tests do not return anything so it is safe to consider that void always appears before test case name - */ - val voidString = testNamePattern - if (!rawTest.contains(voidString)) { - return TestCaseParseResult( - testCase = null, - errorMessage = "The raw Test does not contain $voidString:\n $rawTest", - errorOccurred = true, - ) - } - val interestingPartOfSignature = rawTest.split(voidString)[1] - .split("{")[0] - .split("()")[1] - .trim() - - if (interestingPartOfSignature.contains("throws")) { - throwsException = interestingPartOfSignature.split("throws")[1].trim() - } - - // Get test name - val testName: String = rawTest.split(voidString)[1] - .split("()")[0] - .trim() - - // Get test body and remove opening bracket - var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } - .joinToString("{").trim() - - // remove closing bracket - val tempList = testBody.split("}").toMutableList() - tempList.removeLast() - - if (isLastTestCaseInTestSuite) { - // it is the last test, thus we should remove another closing bracket - if (tempList.isNotEmpty()) { - tempList.removeLast() - } else { - println("WARNING: the final test does not have the enclosing bracket:\n $testBody") - } - } - - testBody = tempList.joinToString("}") - - // Save each line - val rawLines = testBody.split("\n").toMutableList() - rawLines.forEach { rawLine -> - val line = rawLine.trim() - - val type: TestLineType = when { - line.startsWith("//") -> TestLineType.COMMENT - line.isBlank() -> TestLineType.BREAK - line.lowercase().startsWith("assert") -> TestLineType.ASSERTION - else -> TestLineType.CODE - } - - testLines.add(TestLine(type, line)) - } - - val currentTest = TestCaseGeneratedByLLM( - name = testName, - expectedException = expectedException, - throwsException = throwsException, - lines = testLines, - ) - - return TestCaseParseResult( - testCase = currentTest, - errorMessage = "", - errorOccurred = false, - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt new file mode 100644 index 000000000..7bc818cd0 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt @@ -0,0 +1,175 @@ +package org.jetbrains.research.testspark.core.test.strategies + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestCaseParseResult +import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM + +class JUnitTestSuiteParserStrategy { + companion object { + fun parseJUnitTestSuite( + rawText: String, + junitVersion: JUnitVersion, + importPattern: Regex, + packageName: String, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestSuiteGeneratedByLLM? { + if (rawText.isBlank()) { + return null + } + + try { + val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText + + // save imports + val imports = importPattern.findAll(rawCode) + .map { it.groupValues[0] } + .toSet() + + // save RunWith + val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" + + val testSet: MutableList = rawCode.split("@Test").toMutableList() + + // save annotations and pre-set methods + val otherInfo: String = run { + val otherInfoList = testSet.removeAt(0).split("{").toMutableList() + otherInfoList.removeFirst() + val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" + otherInfo.ifBlank { "" } + } + + // Save the main test cases + val testCases: MutableList = mutableListOf() + val testCaseParser = JUnitTestCaseParser() + + testSet.forEach ca@{ + val rawTest = "@Test$it" + + val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) + val result: TestCaseParseResult = + testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy) + + if (result.errorOccurred) { + println("WARNING: ${result.errorMessage}") + return@ca + } + + val currentTest = result.testCase!! + + // TODO: make logging work + // log.info("New test case: $currentTest") + + testCases.add(currentTest) + } + + val testSuite = TestSuiteGeneratedByLLM( + imports = imports, + packageName = packageName, + runWith = runWith, + otherInfo = otherInfo, + testCases = testCases, + ) + + return testSuite + } catch (e: Exception) { + return null + } + } + } + + private class JUnitTestCaseParser { + fun parse( + rawTest: String, + isLastTestCaseInTestSuite: Boolean, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestCaseParseResult { + var expectedException = "" + var throwsException = "" + val testLines: MutableList = mutableListOf() + + // Get expected Exception + if (rawTest.startsWith("@Test(expected =")) { + expectedException = rawTest.split(")")[0].trim() + } + + // Get unexpected exceptions + /* Each test case should follow fun {...} + Tests do not return anything so it is safe to consider that void always appears before test case name + */ + if (!rawTest.contains(testNamePattern)) { + return TestCaseParseResult( + testCase = null, + errorMessage = "The raw Test does not contain $testNamePattern:\n $rawTest", + errorOccurred = true, + ) + } + val interestingPartOfSignature = rawTest.split(testNamePattern)[1] + .split("{")[0] + .split("()")[1] + .trim() + + if (interestingPartOfSignature.contains("throws")) { + throwsException = interestingPartOfSignature.split("throws")[1].trim() + } + + // Get test name + val testName: String = rawTest.split(testNamePattern)[1] + .split("()")[0] + .trim() + + // Get test body and remove opening bracket + var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } + .joinToString("{").trim() + + // remove closing bracket + val tempList = testBody.split("}").toMutableList() + tempList.removeLast() + + if (isLastTestCaseInTestSuite) { + // it is the last test, thus we should remove another closing bracket + if (tempList.isNotEmpty()) { + tempList.removeLast() + } else { + println("WARNING: the final test does not have the enclosing bracket:\n $testBody") + } + } + + testBody = tempList.joinToString("}") + + // Save each line + val rawLines = testBody.split("\n").toMutableList() + rawLines.forEach { rawLine -> + val line = rawLine.trim() + + val type: TestLineType = when { + line.startsWith("//") -> TestLineType.COMMENT + line.isBlank() -> TestLineType.BREAK + line.lowercase().startsWith("assert") -> TestLineType.ASSERTION + else -> TestLineType.CODE + } + + testLines.add(TestLine(type, line)) + } + + val currentTest = TestCaseGeneratedByLLM( + name = testName, + expectedException = expectedException, + throwsException = throwsException, + lines = testLines, + printTestBodyStrategy = printTestBodyStrategy, + ) + + return TestCaseParseResult( + testCase = currentTest, + errorMessage = "", + errorOccurred = false, + ) + } + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt deleted file mode 100644 index 250ec7cba..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.core.utils - -/** - * Language ID string should be the same as the language name in com.intellij.lang.Language - */ -enum class Language(val languageId: String) { - Java("JAVA"), Kotlin("Kotlin") -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt index 95903bf8c..fb1da6841 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt @@ -6,9 +6,17 @@ val javaImportPattern = options = setOf(RegexOption.MULTILINE), ) +/** + * Parse all the possible Kotlin import patterns + * + * import org.mockito.Mockito.`when` + * import kotlin.math.cos + * import kotlin.math.* + * import kotlin.math.PI as piValue + */ val kotlinImportPattern = Regex( - pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?", + pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?(`\\w*`)?", options = setOf(RegexOption.MULTILINE), ) diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt index 2ebcde0c9..63fbd0abc 100644 --- a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -2,14 +2,17 @@ package org.jetbrains.research.testspark.core.test.parsers.kotlin import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test -import kotlin.test.assertNotNull class KotlinJUnitTestSuiteParserTest { @Test - fun testFunction() { + fun testParseTestSuite() { val text = """ ```kotlin import org.junit.jupiter.api.Assertions.* @@ -109,17 +112,149 @@ class KotlinJUnitTestSuiteParserTest { } ``` """.trimIndent() - val parser = KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern) + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertTrue(testSuite!!.imports.contains("import org.mockito.Mockito.*")) + assertTrue(testSuite.imports.contains("import org.test.Message as TestMessage")) + assertTrue(testSuite.imports.contains("import org.mockito.kotlin.mock")) + + val expectedTestCasesNames = listOf( + "compileTestCases_AllCompilableTest", + "compileTestCases_NoneCompilableTest", + "compileTestCases_SomeCompilableTest", + "compileTestCases_EmptyTestCasesTest", + "compileTestCases_omg", + ) + + testSuite.testCases.forEachIndexed { index, testCase -> + val expected = expectedTestCasesNames[index] + assertEquals(expected, testCase.name) { "${index + 1}st test case has incorrect name" } + } + + assertTrue(testSuite.testCases[4].expectedException.isNotBlank()) + } + + @Test + fun testParseEmptyTestSuite() { + val text = """ + ```kotlin + package com.example.testsuite + + class EmptyTestClass { + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(testSuite!!.packageName, "com.example.testsuite") + assertTrue(testSuite.testCases.isEmpty()) + } + + @Test + fun testParseSingleTestCase() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class SingleTestCaseClass { + @Test + fun singleTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) - assert(testSuite.imports.contains("import org.mockito.Mockito.*")) - assert(testSuite.imports.contains("import org.test.Message as TestMessage")) - assert(testSuite.imports.contains("import org.mockito.kotlin.mock")) - assert(testSuite.testCases[0].name == "compileTestCases_AllCompilableTest") - assert(testSuite.testCases[1].name == "compileTestCases_NoneCompilableTest") - assert(testSuite.testCases[2].name == "compileTestCases_SomeCompilableTest") - assert(testSuite.testCases[3].name == "compileTestCases_EmptyTestCasesTest") - assert(testSuite.testCases[4].name == "compileTestCases_omg") - assert(testSuite.testCases[4].expectedException.isNotBlank()) + assertEquals(1, testSuite!!.testCases.size) + assertEquals("singleTestCase", testSuite.testCases[0].name) + } + + @Test + fun testParseTwoTestCases() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class TwoTestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + + @Test + fun secondTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(2, testSuite!!.testCases.size) + assertEquals("firstTestCase", testSuite.testCases[0].name) + assertEquals("secondTestCase", testSuite.testCases[1].name) + } + + @Test + fun testParseTwoTestCasesWithDifferentPackage() { + val code1 = """ + ```kotlin + package org.pkg1 + + import org.junit.jupiter.api.Test + + class TestCasesClass1 { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val code2 = """ + ```kotlin + package org.pkg2 + + import org.junit.jupiter.api.Test + + class 2TestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + + // packageName will be set to 'org.pkg1' + val testSuite1 = parser.parseTestSuite(code1) + + val testSuite2 = parser.parseTestSuite(code2) + + assertNotNull(testSuite1) + assertNotNull(testSuite2) + assertEquals("org.pkg1", testSuite1!!.packageName) + assertEquals("org.pkg2", testSuite2!!.packageName) } } diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index 007bdbff7..087485827 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -14,6 +14,7 @@ import org.jetbrains.research.testspark.core.utils.javaImportPattern import org.jetbrains.research.testspark.core.utils.javaPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -33,29 +34,12 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - javaPackagePattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - javaImportPattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + javaPackagePattern, + javaImportPattern, + ) override val classType: ClassType get() { @@ -68,6 +52,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { return ClassType.CLASS } + override val rBrace: Int? = psiClass.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val query = ClassInheritorsSearch.search(psiClass, scope, false) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index 8b513deda..d2b8dac35 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -4,23 +4,27 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiMethod import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Java + override val language: SupportedLanguage get() = SupportedLanguage.Java private val log = Logger.getInstance(this::class.java) @@ -84,7 +88,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { val cutPsiClass = getSurroundingClass(caretOffset)!! var currentPsiClass = cutPsiClass @@ -138,39 +142,44 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + // The cut is always not null for Java, because all functions are always inside the class + val interestingPsiClasses = cut!!.getInterestingPsiClassesWithQualifiedNames(psiMethod) log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper? val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper? val line: Int? = getSurroundingLine(caret.offset) - javaPsiClassWrapped?.let { result.add(getClassHTMLDisplayName(it)) } - javaPsiMethodWrapped?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (javaPsiClassWrapped != null && javaPsiMethodWrapped != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${javaPsiClassWrapped.qualifiedName} \n" + - " 2) Method ${javaPsiMethodWrapped.name} \n" + - " 3) Line $line", - ) - } + javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${javaPsiClassWrapped?.qualifiedName ?: "no class"} \n" + + " 2) Method ${javaPsiMethodWrapped?.name ?: "no method"} \n" + + " 3) Line $line", + ) - return result.toArray() + return result } + override fun getPackageName() = (psiFile as PsiJavaFile).packageName + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = diff --git a/java/src/main/resources/META-INF/testspark-java.xml b/java/src/main/resources/META-INF/testspark-java.xml deleted file mode 100644 index 180580ca7..000000000 --- a/java/src/main/resources/META-INF/testspark-java.xml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index 8ac75755c..50cc12f0f 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -21,6 +21,7 @@ import org.jetbrains.research.testspark.core.utils.kotlinImportPattern import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -61,29 +62,12 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - kotlinPackagePattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - kotlinImportPattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + kotlinPackagePattern, + kotlinImportPattern, + ) override val classType: ClassType get() { @@ -97,6 +81,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra } } + override val rBrace: Int? = psiClass.body?.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val lightClass = psiClass.toLightClass() @@ -116,11 +102,9 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra method.psiFunction.valueParameters.forEach { parameter -> val typeReference = parameter.typeReference - if (typeReference != null) { - val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) - if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { - interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) - } + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 13749bd35..ca131f7da 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -4,6 +4,7 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass @@ -16,19 +17,22 @@ import org.jetbrains.kotlin.idea.base.psi.kotlinFqName import org.jetbrains.kotlin.idea.caches.resolve.analyze import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtFunction import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper -class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { +class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Kotlin + override val language: SupportedLanguage get() = SupportedLanguage.Kotlin private val log = Logger.getInstance(this::class.java) @@ -85,9 +89,10 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { - val cutPsiClass = getSurroundingClass(caretOffset)!! + val cutPsiClass = getSurroundingClass(caretOffset) ?: return + // will be null for the top level function var currentPsiClass = cutPsiClass for (index in 0 until maxPolymorphismDepth) { if (!classesToTest.contains(currentPsiClass)) { @@ -143,39 +148,45 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + val interestingPsiClasses = + cut?.getInterestingPsiClassesWithQualifiedNames(psiMethod) + ?: (psiMethod as KotlinPsiMethodWrapper).getInterestingPsiClassesWithQualifiedNames() log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) val line: Int? = getSurroundingLine(caret.offset)?.plus(1) - ktClass?.let { result.add(getClassHTMLDisplayName(it)) } - ktFunction?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (ktClass != null && ktFunction != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${ktClass.qualifiedName} \n" + - " 2) Method ${ktFunction.name} \n" + - " 3) Line $line", - ) - } + ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${ktClass?.qualifiedName ?: "no class"} \n" + + " 2) Method ${ktFunction?.name ?: "no method"} \n" + + " 3) Line $line", + ) - return result.toArray() + return result } + override fun getPackageName() = (psiFile as KtFile).packageFqName.asString() + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = @@ -184,7 +195,7 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { override fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String { psiMethod as KotlinPsiMethodWrapper return when { - psiMethod.isTopLevelFunction -> "top-level function" + psiMethod.isTopLevelFunction -> "top-level function ${psiMethod.name}" psiMethod.isSecondaryConstructor -> "secondary constructor" psiMethod.isPrimaryConstructor -> "constructor" psiMethod.isDefaultMethod -> "default method ${psiMethod.name}" diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index a142aaaa8..c993fd808 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -68,6 +68,26 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { return lineNumber in startLine..endLine } + /** + * Returns a set of `PsiClassWrapper` instances for non-standard Kotlin classes referenced by the + * parameters of the current function. + * + * @return A mutable set of `PsiClassWrapper` instances representing non-standard Kotlin classes. + */ + fun getInterestingPsiClassesWithQualifiedNames(): MutableSet { + val interestingPsiClasses = mutableSetOf() + + psiFunction.valueParameters.forEach { parameter -> + val typeReference = parameter.typeReference + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) + } + } + + return interestingPsiClasses + } + /** * Generates the return descriptor for a method. * diff --git a/kotlin/src/main/resources/META-INF/testspark-kotlin.xml b/kotlin/src/main/resources/META-INF/testspark-kotlin.xml deleted file mode 100644 index 22e5e05c8..000000000 --- a/kotlin/src/main/resources/META-INF/testspark-kotlin.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/langwrappers/build.gradle.kts b/langwrappers/build.gradle.kts index 74ec82496..317debb35 100644 --- a/langwrappers/build.gradle.kts +++ b/langwrappers/build.gradle.kts @@ -5,7 +5,6 @@ plugins { repositories { mavenCentral() - // Add any other repositories you need } dependencies { @@ -17,7 +16,6 @@ dependencies { intellij { rootProject.properties["platformVersion"]?.let { version.set(it.toString()) } plugins.set(listOf("java")) - downloadSources.set(true) } tasks.named("verifyPlugin") { enabled = false } diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt new file mode 100644 index 000000000..0982b9ced --- /dev/null +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt @@ -0,0 +1,7 @@ +package org.jetbrains.research.testspark.langwrappers + +import com.intellij.psi.PsiFile + +interface LanguageClassTextExtractor { + fun extract(file: PsiFile, classText: String, packagePattern: Regex, importPattern: Regex): String +} diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index f61dc7a1b..c6f98afeb 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -1,11 +1,15 @@ package org.jetbrains.research.testspark.langwrappers import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.editor.Document import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType + +typealias CodeTypeDisplayName = Pair /** * Interface representing a wrapper for PSI methods, @@ -40,12 +44,14 @@ interface PsiMethodWrapper { * @property name The name of a class * @property qualifiedName The qualified name of the class. * @property text The text of the class. - * @property fullText The source code of the class (with package and imports). - * @property virtualFile - * @property containingFile File where the method is located - * @property superClass The super class of the class * @property methods All methods in the class * @property allMethods All methods in the class and all its superclasses + * @property superClass The super class of the class + * @property virtualFile Virtual file where the class is located + * @property containingFile File where the method is located + * @property fullText The source code of the class (with package and imports). + * @property classType The type of the class + * @property rBrace The offset of the closing brace * */ interface PsiClassWrapper { val name: String @@ -58,6 +64,7 @@ interface PsiClassWrapper { val containingFile: PsiFile val fullText: String val classType: ClassType + val rBrace: Int? /** * Searches for subclasses of the current class within the given project. @@ -81,7 +88,7 @@ interface PsiClassWrapper { * handling the PSI (Program Structure Interface) for different languages. */ interface PsiHelper { - val language: Language + val language: SupportedLanguage /** * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile. @@ -133,7 +140,7 @@ interface PsiHelper { * @return A mutable set of interesting PsiClasses. */ fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet @@ -145,7 +152,7 @@ interface PsiHelper { * The array contains the class display name, method display name (if present), and the line number (if present). * The line number is prefixed with "Line". */ - fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? + fun getCurrentListOfCodeTypes(e: AnActionEvent): List /** * Helper for generating method descriptors for methods. @@ -160,8 +167,8 @@ interface PsiHelper { * * @param project The project in which to collect classes to test. * @param classesToTest The list of classes to test. - * @param psiHelper The PSI helper instance to use for collecting classes. * @param caretOffset The caret offset in the file. + * @param maxPolymorphismDepth Check if cut has any user-defined superclass */ fun collectClassesToTest( project: Project, @@ -170,6 +177,21 @@ interface PsiHelper { maxPolymorphismDepth: Int, ) + /** + * Get the package name of the file. + */ + fun getPackageName(): String + + /** + * Get the module of the file. + */ + fun getModuleFromPsiFile(): com.intellij.openapi.module.Module + + /** + * Get the module of the file. + */ + fun getDocumentFromPsiFile(): Document? + /** * Gets the display line number. * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry. diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt new file mode 100644 index 000000000..643cdee34 --- /dev/null +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.langwrappers.strategies + +import com.intellij.psi.PsiFile +import org.jetbrains.research.testspark.langwrappers.LanguageClassTextExtractor + +/** +Direct implementor for the Java and Kotlin PsiWrappers + */ +class JavaKotlinClassTextExtractor : LanguageClassTextExtractor { + + override fun extract( + file: PsiFile, + classText: String, + packagePattern: Regex, + importPattern: Regex, + ): String { + var fullText = "" + val fileText = file.text + + // get package + packagePattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n\n" + } + + // get imports + importPattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n" + } + + // Add class code + fullText += classText + + return fullText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 3b08ca009..a6f342882 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -17,6 +17,7 @@ import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelFactory import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.display.TestSparkIcons import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider @@ -76,7 +77,6 @@ class TestSparkAction : AnAction() { if (psiHelper == null) { // TODO exception } - e.presentation.isEnabled = psiHelper!!.getCurrentListOfCodeTypes(e) != null } /** @@ -111,18 +111,18 @@ class TestSparkAction : AnAction() { return psiHelper!! } - private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e)!! + private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e) private val caretOffset: Int = e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret!!.offset private val fileUrl = e.dataContext.getData(CommonDataKeys.VIRTUAL_FILE)!!.presentableUrl - private val codeTypeButtons: MutableList = mutableListOf() + private val codeTypeButtons: MutableList> = mutableListOf() private val codeTypeButtonGroup = ButtonGroup() private val nextButton = JButton(PluginLabelsBundle.get("next")) private val cardLayout = CardLayout() private val llmSetupPanelFactory = LLMSetupPanelFactory(e, project) - private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project) + private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project, psiHelper.language) private val evoSuitePanelFactory = EvoSuitePanelFactory(project) init { @@ -198,16 +198,19 @@ class TestSparkAction : AnAction() { testGeneratorPanel.add(llmButton) testGeneratorPanel.add(evoSuiteButton) - for (codeType in codeTypes) { - val button = JRadioButton(codeType as String) - codeTypeButtons.add(button) + for ((codeType, codeTypeName) in codeTypes) { + val button = JRadioButton(codeTypeName) + codeTypeButtons.add(codeType to button) codeTypeButtonGroup.add(button) } val codesToTestPanel = JPanel() codesToTestPanel.add(JLabel("Select the code type:")) - if (codeTypeButtons.size == 1) codeTypeButtons[0].isSelected = true - for (button in codeTypeButtons) codesToTestPanel.add(button) + if (codeTypeButtons.size == 1) { + // A single button is selected by default + codeTypeButtons[0].second.isSelected = true + } + for ((_, button) in codeTypeButtons) codesToTestPanel.add(button) val middlePanel = FormBuilder.createFormBuilder() .setFormLeftIndent(10) @@ -253,7 +256,7 @@ class TestSparkAction : AnAction() { updateNextButton() } - for (button in codeTypeButtons) { + for ((_, button) in codeTypeButtons) { button.addActionListener { llmSetupPanelFactory.setPromptEditorType(button.text) updateNextButton() @@ -330,33 +333,36 @@ class TestSparkAction : AnAction() { if (!testGenerationController.isGeneratorRunning(project)) { val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode() - if (codeTypeButtons[0].isSelected) { - tool.generateTestsForClass( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[1].isSelected) { - tool.generateTestsForMethod( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[2].isSelected) { - tool.generateTestsForLine( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) + for ((codeType, button) in codeTypeButtons) { + if (button.isSelected) { + when (codeType) { + CodeType.CLASS -> tool.generateTestsForClass( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.METHOD -> tool.generateTestsForMethod( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.LINE -> tool.generateTestsForLine( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + } + break + } } } @@ -376,10 +382,7 @@ class TestSparkAction : AnAction() { */ private fun updateNextButton() { val isTestGeneratorButtonGroupSelected = llmButton.isSelected || evoSuiteButton.isSelected - var isCodeTypeButtonGroupSelected = false - for (button in codeTypeButtons) { - isCodeTypeButtonGroupSelected = isCodeTypeButtonGroupSelected || button.isSelected - } + val isCodeTypeButtonGroupSelected = codeTypeButtons.any { it.second.isSelected } nextButton.isEnabled = isTestGeneratorButtonGroupSelected && isCodeTypeButtonGroupSelected if ((llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) || @@ -393,4 +396,4 @@ class TestSparkAction : AnAction() { } override fun getActionUpdateThread(): ActionUpdateThread = ActionUpdateThread.BGT -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index bb2c5a53f..b6b77a0ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,7 +4,8 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.java.LLMTestSampleHelper +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup import javax.swing.JButton @@ -12,7 +13,7 @@ import javax.swing.JLabel import javax.swing.JPanel import javax.swing.JRadioButton -class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { +class LLMSampleSelectorFactory(private val project: Project, private val language: SupportedLanguage) : PanelFactory { // init components private val selectionTypeButtons: MutableList = mutableListOf( JRadioButton(PluginLabelsBundle.get("provideTestSample")), @@ -128,7 +129,7 @@ class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { } addButton.addActionListener { - val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes) + val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes, language) testSamplePanelFactories.add(testSamplePanelFactory) val testSamplePanel = testSamplePanelFactory.getTestSamplePanel() val codeScrollPanel = testSamplePanelFactory.getCodeScrollPanel() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt index 69d5db9f3..8afe31fc8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt @@ -34,7 +34,7 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan private val defaultModulesArray = arrayOf("") private var modelSelector = ComboBox(defaultModulesArray) private var llmUserTokenField = JTextField(30) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) private val backLlmButton = JButton(PluginLabelsBundle.get("back")) private val okLlmButton = JButton(PluginLabelsBundle.get("next")) private val junitSelector = JUnitCombobox(e) @@ -142,6 +142,10 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan llmSettingsState.grazieToken = llmPlatforms[index].token llmSettingsState.grazieModel = llmPlatforms[index].model } + if (llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = llmPlatforms[index].token + llmSettingsState.huggingFaceModel = llmPlatforms[index].model + } } llmSettingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt index 97cf6d49a..251a45f27 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt @@ -10,6 +10,7 @@ import com.intellij.openapi.ui.ComboBox import com.intellij.ui.LanguageTextField import com.intellij.ui.components.JBScrollPane import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.IconButtonCreator import org.jetbrains.research.testspark.display.ModifiedLinesGetter import org.jetbrains.research.testspark.display.TestCaseDocumentCreator @@ -25,11 +26,12 @@ class TestSamplePanelFactory( private val middlePanel: JPanel, private val testNames: MutableList, private val initialTestCodes: MutableList, + private val language: SupportedLanguage, ) { // init components private val currentTestCodes = initialTestCodes.toMutableList() private val languageTextField = LanguageTextField( - Language.findLanguageByID("JAVA"), + Language.findLanguageByID(language.languageId), project, initialTestCodes[0], TestCaseDocumentCreator("TestSample"), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index b8b0654d3..499abf1c1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -18,7 +18,8 @@ import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.llm.JsonEncoding @@ -26,6 +27,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.Llm @@ -172,6 +174,12 @@ class TestSparkStarter : ApplicationStarter { // Start test generation val indicator = HeadlessProgressIndicator() val errorMonitor = DefaultErrorMonitor() + val testCompiler = TestCompilerFactory.create( + project, + settingsState.junitVersion, + psiHelper.language, + projectSDKPath.toString(), + ) val uiContext = llmProcessManager.runTestGenerator( indicator, FragmentToTestData(CodeType.CLASS), @@ -192,6 +200,7 @@ class TestSparkStarter : ApplicationStarter { classPath, projectContext, projectSDKPath, + testCompiler, ) } else { println("[TestSpark Starter] Test generation failed") @@ -237,6 +246,7 @@ class TestSparkStarter : ApplicationStarter { classPath: String, projectContext: ProjectContext, projectSDKPath: Path, + testCompiler: TestCompiler, ) { val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}" println("Run tests in $targetDirectory") @@ -246,6 +256,7 @@ class TestSparkStarter : ApplicationStarter { var testcaseName = it.nameWithoutExtension.removePrefix("Generated") testcaseName = testcaseName[0].lowercaseChar() + testcaseName.substring(1) // The current test is compiled and is ready to run jacoco + val testExecutionError = TestProcessor(project, projectSDKPath).createXmlFromJacoco( it.nameWithoutExtension, "$targetDirectory${File.separator}jacoco-${it.nameWithoutExtension}", @@ -254,6 +265,7 @@ class TestSparkStarter : ApplicationStarter { packageList.joinToString("."), out, projectContext, + testCompiler, ) // Saving exception (if exists) thrown during the test execution saveException(testcaseName, targetDirectory, testExecutionError) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt index 0cf79dddb..3c289bb11 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.data +import org.jetbrains.research.testspark.core.test.data.CodeType + /** * Data about test objects that require test generators. */ diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt index f17e8720b..99b0ec5ab 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt @@ -25,17 +25,20 @@ import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.helpers.ReportHelper import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestClassCodeAnalyzerFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.test.JUnitTestSuitePresenter @@ -58,7 +61,7 @@ import javax.swing.border.MatteBorder class TestCasePanelFactory( private val project: Project, - private val language: org.jetbrains.research.testspark.core.utils.Language, + private val language: SupportedLanguage, private val testCase: TestCase, editor: Editor, private val checkbox: JCheckBox, @@ -193,7 +196,10 @@ class TestCasePanelFactory( val clipboard: Clipboard = Toolkit.getDefaultToolkit().systemClipboard clipboard.setContents( StringSelection( - project.service().getEditor(testCase.testName)!!.document.text, + when (language) { + SupportedLanguage.Kotlin -> project.service().getEditor(testCase.testName)!!.document.text + SupportedLanguage.Java -> project.service().getEditor(testCase.testName)!!.document.text + }, ), null, ) @@ -386,7 +392,10 @@ class TestCasePanelFactory( } ReportHelper.updateTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service().updateUI() + SupportedLanguage.Java -> project.service().updateUI() + } } /** @@ -454,12 +463,12 @@ class TestCasePanelFactory( } private fun addTest(testSuite: TestSuiteGeneratedByLLM) { - val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput) + val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput, language) WriteCommandAction.runWriteCommandAction(project) { uiContext.errorMonitor.clear() val code = testSuitePresenter.toString(testSuite) - testCase.testName = JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, code) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, code) testCase.testCode = code // update numbers @@ -517,15 +526,24 @@ class TestCasePanelFactory( private fun runTest(indicator: CustomProgressIndicator) { indicator.setText("Executing ${testCase.testName}") + val fileName = TestClassCodeAnalyzerFactory.create(language).getFileNameFromTestCaseCode(testCase.testName) + + val testCompiler = TestCompilerFactory.create( + project, + llmSettingsState.junitVersion, + language, + ) + val newTestCase = TestProcessor(project) .processNewTestCase( - "${JavaClassBuilderHelper.getClassFromTestCaseCode(testCase.testCode)}.java", + fileName, testCase.id, testCase.testName, testCase.testCode, - uiContext!!.testGenerationOutput.packageLine, + uiContext!!.testGenerationOutput.packageName, uiContext.testGenerationOutput.resultPath, uiContext.projectContext, + testCompiler, ) testCase.coveredLines = newTestCase.coveredLines @@ -585,13 +603,23 @@ class TestCasePanelFactory( */ private fun remove() { // Remove the test case from the cache - project.service().removeTestCase(testCase.testName) + when (language) { + SupportedLanguage.Kotlin -> project.service().removeTestCase(testCase.testName) + + SupportedLanguage.Java -> project.service().removeTestCase(testCase.testName) + } runTestButton.isEnabled = false isRemoved = true ReportHelper.removeTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service() + .updateUI() + + SupportedLanguage.Java -> project.service() + .updateUI() + } } /** @@ -663,8 +691,7 @@ class TestCasePanelFactory( * Updates the current test case with the specified test name and test code. */ private fun updateTestCaseInformation() { - testCase.testName = - JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, languageTextField.document.text) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, languageTextField.document.text) testCase.testCode = languageTextField.document.text } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt index 31cc7b9a6..b8f90918c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt @@ -1,6 +1,5 @@ package org.jetbrains.research.testspark.display -import com.intellij.openapi.components.service import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task @@ -8,20 +7,20 @@ import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.display.strategies.TopButtonsPanelStrategy import java.awt.Dimension import java.util.LinkedList import java.util.Queue import javax.swing.Box import javax.swing.BoxLayout import javax.swing.JButton -import javax.swing.JCheckBox import javax.swing.JLabel import javax.swing.JOptionPane import javax.swing.JPanel -class TopButtonsPanelFactory(private val project: Project) { +class TopButtonsPanelFactory(private val project: Project, private val language: SupportedLanguage) { private var runAllButton: JButton = createRunAllTestButton() private var selectAllButton: JButton = IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) @@ -64,28 +63,26 @@ class TopButtonsPanelFactory(private val project: Project) { * Updates the labels. */ fun updateTopLabels() { - var numberOfPassedTests = 0 - for (testCasePanelFactory in testCasePanelFactories) { - if (testCasePanelFactory.isRemoved()) continue - val error = testCasePanelFactory.getError() - if ((error is String) && error.isEmpty()) { - numberOfPassedTests++ - } - } - testsSelectedLabel.text = String.format( - testsSelectedText, - project.service().getTestsSelected(), - project.service().getTestCasePanels().size, - ) - testsPassedLabel.text = - String.format( + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.updateTopJavaLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, testsPassedText, - numberOfPassedTests, - project.service().getTestCasePanels().size, + runAllButton, + ) + + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.updateTopKotlinLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, + testsPassedText, + runAllButton, ) - runAllButton.isEnabled = false - for (testCasePanelFactory in testCasePanelFactories) { - runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() } } @@ -105,31 +102,20 @@ class TopButtonsPanelFactory(private val project: Project) { * @param selected whether the checkboxes have to be selected or not */ private fun toggleAllCheckboxes(selected: Boolean) { - project.service().getTestCasePanels().forEach { (_, jPanel) -> - val checkBox = jPanel.getComponent(0) as JCheckBox - checkBox.isSelected = selected + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.toggleAllJavaCheckboxes(selected, project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.toggleAllKotlinCheckboxes(selected, project) } - project.service() - .setTestsSelected(if (selected) project.service().getTestCasePanels().size else 0) } /** * Removes all test cases from the cache and tool window UI. */ private fun removeAllTestCases() { - // Ask the user for the confirmation - val choice = JOptionPane.showConfirmDialog( - null, - PluginMessagesBundle.get("removeAllMessage"), - PluginMessagesBundle.get("confirmationTitle"), - JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE, - ) - - // Cancel the operation if the user did not press "Yes" - if (choice == JOptionPane.NO_OPTION) return - - project.service().clear() + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.removeAllJavaTestCases(project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.removeAllKotlinTestCases(project) + } } /** diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt new file mode 100644 index 000000000..07d8f88f2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt @@ -0,0 +1,138 @@ +package org.jetbrains.research.testspark.display.strategies + +import com.intellij.openapi.components.service +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JLabel +import javax.swing.JOptionPane + +class TopButtonsPanelStrategy { + companion object { + fun toggleAllJavaCheckboxes(selected: Boolean, project: Project) { + project.service().getTestCasePanels().forEach { (_, jPanel) -> + val checkBox = jPanel.getComponent(0) as JCheckBox + checkBox.isSelected = selected + } + project.service() + .setTestsSelected( + if (selected) project.service().getTestCasePanels().size else 0, + ) + } + + fun toggleAllKotlinCheckboxes(selected: Boolean, project: Project) { + project.service().getTestCasePanels().forEach { (_, jPanel) -> + val checkBox = jPanel.getComponent(0) as JCheckBox + checkBox.isSelected = selected + } + project.service() + .setTestsSelected( + if (selected) project.service().getTestCasePanels().size else 0, + ) + } + + fun updateTopJavaLabels( + testCasePanelFactories: ArrayList, + testsSelectedLabel: JLabel, + testsSelectedText: String, + project: Project, + testsPassedLabel: JLabel, + testsPassedText: String, + runAllButton: JButton, + ) { + var numberOfPassedTests = 0 + for (testCasePanelFactory in testCasePanelFactories) { + if (testCasePanelFactory.isRemoved()) continue + val error = testCasePanelFactory.getError() + if ((error is String) && error.isEmpty()) { + numberOfPassedTests++ + } + } + testsSelectedLabel.text = String.format( + testsSelectedText, + project.service().getTestsSelected(), + project.service().getTestCasePanels().size, + ) + testsPassedLabel.text = + String.format( + testsPassedText, + numberOfPassedTests, + project.service().getTestCasePanels().size, + ) + runAllButton.isEnabled = false + for (testCasePanelFactory in testCasePanelFactories) { + runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() + } + } + + fun updateTopKotlinLabels( + testCasePanelFactories: ArrayList, + testsSelectedLabel: JLabel, + testsSelectedText: String, + project: Project, + testsPassedLabel: JLabel, + testsPassedText: String, + runAllButton: JButton, + ) { + var numberOfPassedTests = 0 + for (testCasePanelFactory in testCasePanelFactories) { + if (testCasePanelFactory.isRemoved()) continue + val error = testCasePanelFactory.getError() + if ((error is String) && error.isEmpty()) { + numberOfPassedTests++ + } + } + testsSelectedLabel.text = String.format( + testsSelectedText, + project.service().getTestsSelected(), + project.service().getTestCasePanels().size, + ) + testsPassedLabel.text = + String.format( + testsPassedText, + numberOfPassedTests, + project.service().getTestCasePanels().size, + ) + runAllButton.isEnabled = false + for (testCasePanelFactory in testCasePanelFactories) { + runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() + } + } + + fun removeAllJavaTestCases(project: Project) { + // Ask the user for the confirmation + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("removeAllMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.YES_NO_OPTION, + JOptionPane.QUESTION_MESSAGE, + ) + + // Cancel the operation if the user did not press "Yes" + if (choice == JOptionPane.NO_OPTION) return + + project.service().clear() + } + + fun removeAllKotlinTestCases(project: Project) { + // Ask the user for the confirmation + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("removeAllMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.YES_NO_OPTION, + JOptionPane.QUESTION_MESSAGE, + ) + + // Cancel the operation if the user did not press "Yes" + if (choice == JOptionPane.NO_OPTION) return + + project.service().clear() + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt index bcad7a834..dee6a2b0e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt @@ -16,7 +16,7 @@ import com.intellij.ui.components.JBLabel import com.intellij.ui.components.JBScrollPane import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.services.EvoSuiteSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState import java.awt.Color import java.awt.Dimension @@ -130,7 +130,7 @@ class CoverageHelper( * @param name name of the test to highlight */ private fun highlightInToolwindow(name: String) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightTestCase(name) } @@ -141,7 +141,7 @@ class CoverageHelper( * @param map map of mutant operations -> List of names of tests which cover the mutants */ private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() }) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt deleted file mode 100644 index cf62202b3..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ /dev/null @@ -1,204 +0,0 @@ -package org.jetbrains.research.testspark.helpers - -import com.github.javaparser.ParseProblemException -import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.visitor.VoidVisitorAdapter -import com.intellij.lang.java.JavaLanguage -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.project.Project -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiFile -import com.intellij.psi.PsiFileFactory -import com.intellij.psi.codeStyle.CodeStyleManager -import org.jetbrains.research.testspark.core.data.TestGenerationData -import java.io.File - -object JavaClassBuilderHelper { - /** - * Generates the code for a test class. - * - * @param className the name of the test class - * @param body the body of the test class - * @return the generated code as a string - */ - fun generateCode( - project: Project, - className: String, - body: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - testGenerationData: TestGenerationData, - ): String { - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) - - // Add each test (exclude expected exception) - testFullText += body - - // close the test class - testFullText += "}" - - testFullText.replace("\r\n", "\n") - - /** - * for better readability and make the tests shorter, we reduce the number of line breaks: - * when we have three or more sequential \n, reduce it to two. - */ - return formatJavaCode(project, Regex("\n\n\n(\n)*").replace(testFullText, "\n\n"), testGenerationData) - } - - /** - * Returns the upper part of test suite (package name, imports, and test class name) as a string. - * - * @return the upper part of test suite (package name, imports, and test class name) as a string. - */ - private fun printUpperPart( - className: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - ): String { - var testText = "" - - // Add package - if (packageString.isNotBlank()) { - testText += "package $packageString;\n" - } - - // add imports - imports.forEach { importedElement -> - testText += "$importedElement\n" - } - - testText += "\n" - - // add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith)\n" - } - // open the test class - testText += "public class $className {\n\n" - - // Add other presets (annotations, non-test functions) - if (otherInfo.isNotBlank()) { - testText += otherInfo - } - - return testText - } - - /** - * Finds the test method from a given class with the specified test case name. - * - * @param code The code of the class containing test methods. - * @return The test method as a string, including the "@Test" annotation. - */ - fun getTestMethodCodeFromClassWithTestCase(code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - val upperCutCode = "\t@Test" + code.split("@Test").last() - var methodStarted = false - var balanceOfBrackets = 0 - for (symbol in upperCutCode) { - result += symbol - if (symbol == '{') { - methodStarted = true - balanceOfBrackets++ - } - if (symbol == '}') { - balanceOfBrackets-- - } - if (methodStarted && balanceOfBrackets == 0) { - break - } - } - return result + "\n" - } - } - - /** - * Retrieves the name of the test method from a given Java class with test cases. - * - * @param oldTestCaseName The old name of test case - * @param code The source code of the Java class with test cases. - * @return The name of the test method. If no test method is found, an empty string is returned. - */ - fun getTestMethodNameFromClassWithTestCase(oldTestCaseName: String, code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result = method.nameAsString - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - return oldTestCaseName - } - } - - /** - * Retrieves the class name from the given test case code. - * - * @param code The test case code to extract the class name from. - * @return The class name extracted from the test case code. - */ - fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - - /** - * Formats the given Java code using IntelliJ IDEA's code formatting rules. - * - * @param code The Java code to be formatted. - * @return The formatted Java code. - */ - fun formatJavaCode(project: Project, code: String, generatedTestData: TestGenerationData): String { - var result = "" - WriteCommandAction.runWriteCommandAction(project) { - val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" - // create a temporary PsiFile - val psiFile: PsiFile = PsiFileFactory.getInstance(project) - .createFileFromText( - fileName, - JavaLanguage.INSTANCE, - code, - ) - - CodeStyleManager.getInstance(project).reformat(psiFile) - - val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) - result = document?.text ?: code - - File(fileName).delete() - } - - return result - } -} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index b36fe381a..d10525087 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -12,15 +12,19 @@ import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModif import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIPlatform import java.net.HttpURLConnection import javax.swing.DefaultComboBoxModel @@ -67,6 +71,9 @@ object LLMHelper { if (platformSelector.selectedItem!!.toString() == settingsState.grazieName) { models = getGrazieModels() } + if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) { + models = getHuggingFaceModels() + } modelSelector.model = DefaultComboBoxModel(models) for (index in llmPlatforms.indices) { if (llmPlatforms[index].name == settingsState.openAIName && @@ -81,6 +88,12 @@ object LLMHelper { modelSelector.selectedItem = settingsState.grazieModel llmPlatforms[index].model = modelSelector.selectedItem!!.toString() } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + modelSelector.selectedItem = settingsState.huggingFaceModel + llmPlatforms[index].model = modelSelector.selectedItem!!.toString() + } } modelSelector.isEnabled = true if (models.contentEquals(arrayOf(""))) modelSelector.isEnabled = false @@ -112,6 +125,12 @@ object LLMHelper { llmUserTokenField.text = settingsState.grazieToken llmPlatforms[index].token = settingsState.grazieToken } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + llmUserTokenField.text = settingsState.huggingFaceToken + llmPlatforms[index].token = settingsState.huggingFaceToken + } } } @@ -185,8 +204,6 @@ object LLMHelper { if (isGrazieClassLoaded()) { platformSelector.model = DefaultComboBoxModel(llmPlatforms.map { it.name }.toTypedArray()) platformSelector.selectedItem = settingsState.currentLLMPlatformName - } else { - platformSelector.isEnabled = false } llmUserTokenField.toolTipText = LLMSettingsBundle.get("llmToken") @@ -202,7 +219,7 @@ object LLMHelper { * @return The list of LLMPlatforms. */ fun getLLLMPlatforms(): List { - return listOf(OpenAIPlatform(), GraziePlatform()) + return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform()) } /** @@ -230,7 +247,7 @@ object LLMHelper { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun testModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -244,13 +261,28 @@ object LLMHelper { return null } + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + ) + + val testsAssembler = TestsAssemblerFactory.create( + indicator, + testGenerationOutput, + testSuiteParser, + jUnitVersion, + ) + val testSuite = executeTestCaseModificationRequest( language, testCase, task, indicator, requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, testGenerationOutput), + testsAssembler, errorMonitor, ) return testSuite @@ -328,4 +360,13 @@ object LLMHelper { arrayOf("") } } + + /** + * Retrieves the available HuggingFace models. + * + * @return an array of string representing the available HuggingFace models + */ + private fun getHuggingFaceModels(): Array { + return arrayOf("Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct") + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b20891ed4 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.helpers + +/** + * Interface for retrieving information from test class code. + */ +interface TestClassCodeAnalyzer { + /** + * Extracts the code of the first test method found in the given class code. + * + * @param classCode The code of the class containing test methods. + * @return The code of the first test method as a string, including the "@Test" annotation. + */ + fun extractFirstTestMethodCode(classCode: String): String + + /** + * Retrieves the name of the first test method found in the given class code. + * + * @param oldTestCaseName The old name of a test case + * @param classCode The source code of the class containing test methods. + * @return The name of the first test method. If no test method is found, an empty string is returned. + */ + fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String + + /** + * Retrieves the class name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getClassFromTestCaseCode(code: String): String + + /** + * Return the right file name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getFileNameFromTestCaseCode(code: String): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt new file mode 100644 index 000000000..7443b1664 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt @@ -0,0 +1,43 @@ +package org.jetbrains.research.testspark.helpers + +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.core.data.TestGenerationData + +/** + * Interface for generating and formatting test class code. + */ +interface TestClassCodeGenerator { + /** + * Generates the code for a test class. + * + * @param project the current project + * @param className the name of the test class + * @param body the body of the test class + * @param imports the set of imports needed in the test class + * @param packageString the package declaration of the test class + * @param runWith the runWith annotation for the test class + * @param otherInfo any other additional information for the test class + * @param testGenerationData the data used for test generation + * @return the generated code as a string + */ + fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String + + /** + * Formats the given Java code using IntelliJ IDEA's code formatting rules. + * + * @param project the current project + * @param code the Java code to be formatted + * @param generatedTestData the data used for generating the test + * @return the formatted Java code + */ + fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..f6f2fd0a9 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt @@ -0,0 +1,78 @@ +package org.jetbrains.research.testspark.helpers.java + +import com.github.javaparser.ParseProblemException +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.visitor.VoidVisitorAdapter +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object JavaTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + val upperCutCode = "\t@Test" + classCode.split("@Test").last() + var methodStarted = false + var balanceOfBrackets = 0 + for (symbol in upperCutCode) { + result += symbol + if (symbol == '{') { + methodStarted = true + balanceOfBrackets++ + } + if (symbol == '}') { + balanceOfBrackets-- + } + if (methodStarted && balanceOfBrackets == 0) { + break + } + } + return result + "\n" + } + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result = method.nameAsString + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + return oldTestCaseName + } + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.java" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt new file mode 100644 index 000000000..46c071d5f --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt @@ -0,0 +1,104 @@ +package org.jetbrains.research.testspark.helpers.java + +import com.intellij.lang.java.JavaLanguage +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.project.Project +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiFileFactory +import com.intellij.psi.codeStyle.CodeStyleManager +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import java.io.File + +object JavaTestClassCodeGenerator : TestClassCodeGenerator { + + private val log = Logger.getInstance(this::class.java) + + override fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String { + var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) + + // Add each test (exclude expected exception) + testFullText += body + + // close the test class + testFullText += "}" + + testFullText.replace("\r\n", "\n") + + /** + * for better readability and make the tests shorter, we reduce the number of line breaks: + * when we have three or more sequential \n, reduce it to two. + */ + return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) + } + + override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { + var result = "" + WriteCommandAction.runWriteCommandAction(project) { + val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" + // create a temporary PsiFile + val psiFile: PsiFile = PsiFileFactory.getInstance(project) + .createFileFromText( + fileName, + JavaLanguage.INSTANCE, + code, + ) + + CodeStyleManager.getInstance(project).reformat(psiFile) + + val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) + result = document?.text ?: code + + File(fileName).delete() + } + + return result + } + + private fun printUpperPart( + className: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + ): String { + var testText = "" + + // Add package + if (packageString.isNotBlank()) { + testText += "package $packageString;\n" + } + + // add imports + imports.forEach { importedElement -> + testText += "$importedElement\n" + } + + testText += "\n" + + // add runWith if exists + if (runWith.isNotBlank()) { + testText += "@RunWith($runWith)\n" + } + // open the test class + testText += "public class $className {\n\n" + + // Add other presets (annotations, non-test functions) + if (otherInfo.isNotBlank()) { + testText += otherInfo + } + + return testText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b21a97dfd --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt @@ -0,0 +1,65 @@ +package org.jetbrains.research.testspark.helpers.kotlin + +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object KotlinTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + val testMethods = StringBuilder() + val lines = classCode.lines() + + var methodStarted = false + var balanceOfBrackets = 0 + + for (line in lines) { + if (!methodStarted && line.contains("@Test")) { + methodStarted = true + testMethods.append(line).append("\n") + } else if (methodStarted) { + testMethods.append(line).append("\n") + for (char in line) { + if (char == '{') { + balanceOfBrackets++ + } else if (char == '}') { + balanceOfBrackets-- + } + } + if (balanceOfBrackets == 0) { + methodStarted = false + testMethods.append("\n") + } + } + } + + return testMethods.toString() + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + val lines = classCode.lines() + var testMethodName = oldTestCaseName + + for (line in lines) { + if (line.contains("@Test")) { + val methodDeclarationLine = lines[lines.indexOf(line) + 1] + val matchResult = Regex("fun\\s+(\\w+)\\s*\\(").find(methodDeclarationLine) + if (matchResult != null) { + testMethodName = matchResult.groupValues[1] + } + break + } + } + return testMethodName + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.kt" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt new file mode 100644 index 000000000..eb10a7aa9 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt @@ -0,0 +1,101 @@ +package org.jetbrains.research.testspark.helpers.kotlin + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.project.Project +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiFileFactory +import com.intellij.psi.codeStyle.CodeStyleManager +import org.jetbrains.kotlin.idea.KotlinLanguage +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import java.io.File + +object KotlinTestClassCodeGenerator : TestClassCodeGenerator { + + private val log = Logger.getInstance(this::class.java) + + override fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String { + log.debug("[KotlinClassBuilderHelper] Generate code for $className") + + var testFullText = + printUpperPart(className, imports, packageString, runWith, otherInfo) + + // Add each test (exclude expected exception) + testFullText += body + + // close the test class + testFullText += "}" + + testFullText.replace("\r\n", "\n") + + // Reduce the number of line breaks for better readability + return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) + } + + override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { + var result = "" + WriteCommandAction.runWriteCommandAction(project) { + val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.kt" + // Create a temporary PsiFile + val psiFile: PsiFile = PsiFileFactory.getInstance(project) + .createFileFromText(fileName, KotlinLanguage.INSTANCE, code) + + CodeStyleManager.getInstance(project).reformat(psiFile) + + val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) + result = document?.text ?: code + + File(fileName).delete() + } + log.info("Formatted result class: $result") + return result + } + + private fun printUpperPart( + className: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + ): String { + var testText = "" + + // Add package + if (packageString.isNotBlank()) { + testText += "package $packageString\n" + } + + // Add imports + imports.forEach { importedElement -> + testText += "$importedElement\n" + } + + testText += "\n" + + // Add runWith if exists + if (runWith.isNotBlank()) { + testText += "@RunWith($runWith::class)\n" + } + + // Open the test class + testText += "class $className {\n\n" + + // Add other presets (annotations, non-test functions) + if (otherInfo.isNotBlank()) { + testText += otherInfo + } + + return testText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt index e3b11555a..6b257f421 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt @@ -1,425 +1,69 @@ package org.jetbrains.research.testspark.services -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.components.Service -import com.intellij.openapi.components.service -import com.intellij.openapi.fileChooser.FileChooser -import com.intellij.openapi.fileChooser.FileChooserDescriptor -import com.intellij.openapi.fileEditor.FileDocumentManager -import com.intellij.openapi.fileEditor.FileEditorManager -import com.intellij.openapi.fileEditor.OpenFileDescriptor -import com.intellij.openapi.fileEditor.TextEditor -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.LocalFileSystem -import com.intellij.openapi.vfs.VirtualFile -import com.intellij.openapi.vfs.VirtualFileManager -import com.intellij.openapi.wm.ToolWindowManager -import com.intellij.psi.PsiClass -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiElementFactory -import com.intellij.psi.PsiJavaFile -import com.intellij.psi.PsiManager -import com.intellij.refactoring.suggested.startOffset +import com.intellij.psi.PsiFile import com.intellij.ui.EditorTextField -import com.intellij.ui.JBColor -import com.intellij.ui.components.JBScrollPane -import com.intellij.ui.content.Content -import com.intellij.ui.content.ContentFactory -import com.intellij.ui.content.ContentManager -import com.intellij.util.containers.stream -import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle import org.jetbrains.research.testspark.core.data.Report -import org.jetbrains.research.testspark.core.data.TestCase -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.data.UIContext -import org.jetbrains.research.testspark.display.TestCasePanelFactory -import org.jetbrains.research.testspark.display.TopButtonsPanelFactory -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.ReportHelper -import java.awt.BorderLayout -import java.awt.Color -import java.awt.Dimension -import java.io.File -import java.util.Locale -import javax.swing.Box -import javax.swing.BoxLayout -import javax.swing.JButton -import javax.swing.JCheckBox -import javax.swing.JOptionPane +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import javax.swing.JPanel -import javax.swing.JSeparator -import javax.swing.SwingConstants -@Service(Service.Level.PROJECT) -class TestCaseDisplayService(private val project: Project) { - private var report: Report? = null - - private val unselectedTestCases = HashMap() - - private var mainPanel: JPanel = JPanel() - - private val topButtonsPanelFactory = TopButtonsPanelFactory(project) - - private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) - - private var allTestCasePanel: JPanel = JPanel() - - private var scrollPane: JBScrollPane = JBScrollPane( - allTestCasePanel, - JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, - ) - - private var testCasePanels: HashMap = HashMap() - - private var testsSelected: Int = 0 - - /** - * Default color for the editors in the tool window - */ - private var defaultEditorColor: Color? = null - - /** - * Content Manager to be able to add / remove tabs from tool window - */ - private var contentManager: ContentManager? = null - - /** - * Variable to keep reference to the coverage visualisation content - */ - private var content: Content? = null - - var uiContext: UIContext? = null - - init { - allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) - mainPanel.layout = BorderLayout() - - mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) - mainPanel.add(scrollPane, BorderLayout.CENTER) - - applyButton.isOpaque = false - applyButton.isContentAreaFilled = false - mainPanel.add(applyButton, BorderLayout.SOUTH) - - applyButton.addActionListener { applyTests() } - } +interface TestCaseDisplayService { /** * Fill the panel with the generated test cases. Remove all previously shown test cases. * Add Tests and their names to a List of pairs (used for highlighting) */ - fun displayTestCases(report: Report, uiContext: UIContext, language: Language) { - this.report = report - this.uiContext = uiContext - - val editor = project.service().editor!! - - allTestCasePanel.removeAll() - testCasePanels.clear() - - addSeparator() - - // TestCasePanelFactories array - val testCasePanelFactories = arrayListOf() - - report.testCaseList.values.forEach { - val testCase = it - val testCasePanel = JPanel() - testCasePanel.layout = BorderLayout() - - // Add a checkbox to select the test - val checkbox = JCheckBox() - checkbox.isSelected = true - checkbox.addItemListener { - // Update the number of selected tests - testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) - - if (checkbox.isSelected) { - ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) - } else { - ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) - } - - updateUI() - } - testCasePanel.add(checkbox, BorderLayout.WEST) - - val testCasePanelFactory = TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) - testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) - testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) - testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) - - testCasePanelFactories.add(testCasePanelFactory) - - testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) - - // Add panel to parent panel - testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) - allTestCasePanel.add(testCasePanel) - addSeparator() - testCasePanels[testCase.testName] = testCasePanel - } - - // Update the number of selected tests (all tests are selected by default) - testsSelected = testCasePanels.size - - topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) - topButtonsPanelFactory.updateTopLabels() - - createToolWindowTab() - } + fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) /** * Adds a separator to the allTestCasePanel. */ - private fun addSeparator() { - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - } + fun addSeparator() /** * Highlight the mini-editor in the tool window whose name corresponds with the name of the test provided * * @param name name of the test whose editor should be highlighted */ - fun highlightTestCase(name: String) { - val myPanel = testCasePanels[name] ?: return - openToolWindowTab() - scrollToPanel(myPanel) - - val editor = getEditor(name) ?: return - val settingsProjectState = project.service().state - val highlightColor = - JBColor( - PluginSettingsBundle.get("colorName"), - Color( - settingsProjectState.colorRed, - settingsProjectState.colorGreen, - settingsProjectState.colorBlue, - 30, - ), - ) - if (editor.background.equals(highlightColor)) return - defaultEditorColor = editor.background - editor.background = highlightColor - returnOriginalEditorBackground(editor) - } + fun highlightTestCase(name: String) /** * Method to open the toolwindow tab with generated tests if not already open. */ - private fun openToolWindowTab() { - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - toolWindowManager.show() - toolWindowManager.contentManager.setSelectedContent(content!!) - } - } + fun openToolWindowTab() /** * Scrolls to the highlighted panel. * * @param myPanel the panel to scroll to */ - private fun scrollToPanel(myPanel: JPanel) { - var sum = 0 - for (component in allTestCasePanel.components) { - if (component == myPanel) { - break - } else { - sum += component.height - } - } - val scroll = scrollPane.verticalScrollBar - scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height - } + fun scrollToPanel(myPanel: JPanel) /** * Removes all coverage highlighting from the editor. */ - private fun removeAllHighlights() { - project.service().editor?.markupModel?.removeAllHighlighters() - } + fun removeAllHighlights() /** * Reset the provided editors color to the default (initial) one after 10 seconds * @param editor the editor whose color to change */ - private fun returnOriginalEditorBackground(editor: EditorTextField) { - Thread { - Thread.sleep(10000) - editor.background = defaultEditorColor - }.start() - } + fun returnOriginalEditorBackground(editor: EditorTextField) /** * Highlight a range of editors * @param names list of test names to pass to highlight function */ - fun highlightCoveredMutants(names: List) { - names.forEach { - highlightTestCase(it) - } - } + fun highlightCoveredMutants(names: List) /** * Show a dialog where the user can select what test class the tests should be applied to, * and apply the selected tests to the test class. */ - private fun applyTests() { - // Filter the selected test cases - val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } - val selectedTestCases = selectedTestCasePanels.map { it.key } - - // Get the test case components (source code of the tests) - val testCaseComponents = selectedTestCases - .map { getEditor(it)!! } - .map { it.document.text } - - // Descriptor for choosing folders and java files - val descriptor = FileChooserDescriptor(true, true, false, false, false, false) - - // Apply filter with folders and java files with main class - WriteCommandAction.runWriteCommandAction(project) { - descriptor.withFileFilter { file -> - file.isDirectory || ( - file.extension?.lowercase(Locale.getDefault()) == "java" && ( - PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile - ).classes.stream().map { it.name } - .toArray() - .contains( - ( - PsiManager.getInstance(project) - .findFile(file) as PsiJavaFile - ).name.removeSuffix(".java"), - ) - ) - } - } - - val fileChooser = FileChooser.chooseFiles( - descriptor, - project, - LocalFileSystem.getInstance().findFileByPath(project.basePath!!), - ) - - /** - * Cancel button pressed - */ - if (fileChooser.isEmpty()) return - - /** - * Chosen files by user - */ - val chosenFile = fileChooser[0] - - /** - * Virtual file of a final java file - */ - var virtualFile: VirtualFile? = null - - /** - * PsiClass of a final java file - */ - var psiClass: PsiClass? = null - - /** - * PsiJavaFile of a final java file - */ - var psiJavaFile: PsiJavaFile? = null - - if (chosenFile.isDirectory) { - // Input new file data - var className: String - var fileName: String - var filePath: String - // Waiting for correct file name input - while (true) { - val jOptionPane = - JOptionPane.showInputDialog( - null, - PluginLabelsBundle.get("optionPaneMessage"), - PluginLabelsBundle.get("optionPaneTitle"), - JOptionPane.PLAIN_MESSAGE, - null, - null, - null, - ) - - // Cancel button pressed - jOptionPane ?: return - - // Get class name from user - className = jOptionPane as String - - // Set file name and file path - fileName = "${className.split('.')[0]}.java" - filePath = "${chosenFile.path}/$fileName" - - // Check the correctness of a class name - if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { - showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) - continue - } - - // Check the existence of a file with this name - if (File(filePath).exists()) { - showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) - continue - } - break - } - - // Create new file and set services of this file - WriteCommandAction.runWriteCommandAction(project) { - chosenFile.createChildData(null, fileName) - virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + fun applyTests() - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { - psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") - } - - psiJavaFile!!.add(psiClass!!) - } - } else { - // Set services of the chosen file - virtualFile = chosenFile - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = psiJavaFile!!.classes[ - psiJavaFile!!.classes.stream().map { it.name }.toArray() - .indexOf(psiJavaFile!!.name.removeSuffix(".java")), - ] - } - - // Add tests to the file - WriteCommandAction.runWriteCommandAction(project) { - appendTestsToClass(testCaseComponents, psiClass!!, psiJavaFile!!) - } - - // Remove the selected test cases from the cache and the tool window UI - removeSelectedTestCases(selectedTestCasePanels) - - // Open the file after adding - FileEditorManager.getInstance(project).openTextEditor( - OpenFileDescriptor(project, virtualFile!!), - true, - ) - } - - private fun showErrorWindow(message: String) { - JOptionPane.showMessageDialog( - null, - message, - PluginLabelsBundle.get("errorWindowTitle"), - JOptionPane.ERROR_MESSAGE, - ) - } + fun showErrorWindow(message: String) /** * Retrieve the editor corresponding to a particular test case @@ -427,11 +71,7 @@ class TestCaseDisplayService(private val project: Project) { * @param testCaseName the name of the test case * @return the editor corresponding to the test case, or null if it does not exist */ - fun getEditor(testCaseName: String): EditorTextField? { - val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null - val middlePanel = middlePanelComponent as JPanel - return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField - } + fun getEditor(testCaseName: String): EditorTextField? /** * Append the provided test cases to the provided class. @@ -440,107 +80,23 @@ class TestCaseDisplayService(private val project: Project) { * @param selectedClass the class which the test cases should be appended to * @param outputFile the output file for tests */ - private fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClass, outputFile: PsiJavaFile) { - // block document - PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!, - ) - - // insert tests to a code - testCaseComponents.reversed().forEach { - val testMethodCode = - JavaClassBuilderHelper.getTestMethodCodeFromClassWithTestCase( - JavaClassBuilderHelper.formatJavaCode( - project, - it.replace("\r\n", "\n") - .replace("verifyException(", "// verifyException("), - uiContext!!.testGenerationOutput, - ), - ) - // Fix Windows line separators - .replace("\r\n", "\n") - - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - testMethodCode, - ) - } - - // insert other info to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - uiContext!!.testGenerationOutput.otherInfo + "\n", - ) - - // insert imports to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, - uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", - ) - - // insert package to a code - outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! - .insertString( - 0, - if (uiContext!!.testGenerationOutput.packageLine.isEmpty()) { - "" - } else { - "package ${uiContext!!.testGenerationOutput.packageLine};\n\n" - }, - ) - } + fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClassWrapper, outputFile: PsiFile) /** * Utility function that returns the editor for a specific file url, * in case it is opened in the IDE */ - fun updateEditorForFileUrl(fileUrl: String) { - val documentManager = FileDocumentManager.getInstance() - // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 - FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { - val currentFile = documentManager.getFile(it.document) - if (currentFile != null) { - if (currentFile.presentableUrl == fileUrl) { - project.service().editor = it - } - } - } - } + fun updateEditorForFileUrl(fileUrl: String) /** * Creates a new toolWindow tab for the coverage visualisation. */ - private fun createToolWindowTab() { - // Remove generated tests tab from content manager if necessary - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - contentManager!!.removeContent(content!!, true) - } - - // If there is no generated tests tab, make it - val contentFactory: ContentFactory = ContentFactory.getInstance() - content = contentFactory.createContent( - mainPanel, - PluginLabelsBundle.get("generatedTests"), - true, - ) - contentManager!!.addContent(content!!) - - // Focus on generated tests tab and open toolWindow if not opened already - contentManager!!.setSelectedContent(content!!) - toolWindowManager.show() - } + fun createToolWindowTab() /** * Closes the tool window and destroys the content of the tab. */ - private fun closeToolWindow() { - contentManager?.removeContent(content!!, true) - ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() - val coverageVisualisationService = project.service() - coverageVisualisationService.closeToolWindowTab() - } + fun closeToolWindow() /** * Removes the selected tests from the cache, removes all the highlights from the editor and closes the tool window. @@ -549,37 +105,16 @@ class TestCaseDisplayService(private val project: Project) { * * @param selectedTestCasePanels the panels of the selected tests */ - private fun removeSelectedTestCases(selectedTestCasePanels: Map) { - selectedTestCasePanels.forEach { removeTestCase(it.key) } - removeAllHighlights() - closeToolWindow() - } - - fun clear() { - // Remove the tests - val testCasePanelsToRemove = testCasePanels.toMap() - removeSelectedTestCases(testCasePanelsToRemove) + fun removeSelectedTestCases(selectedTestCasePanels: Map) - topButtonsPanelFactory.clear() - } + fun clear() /** * A helper method to remove a test case from the cache and from the UI. * * @param testCaseName the name of the test */ - fun removeTestCase(testCaseName: String) { - // Update the number of selected test cases if necessary - if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { - testsSelected-- - } - - // Remove the test panel from the UI - allTestCasePanel.remove(testCasePanels[testCaseName]) - - // Remove the test panel - testCasePanels.remove(testCaseName) - } + fun removeTestCase(testCaseName: String) /** * Updates the user interface of the tool window. @@ -589,36 +124,26 @@ class TestCaseDisplayService(private val project: Project) { * of the topButtonsPanel object. It also checks if there are no more tests remaining * and closes the tool window if that is the case. */ - fun updateUI() { - // Update the UI of the tool window tab - allTestCasePanel.updateUI() - - topButtonsPanelFactory.updateTopLabels() - - // If no more tests are remaining, close the tool window - if (testCasePanels.size == 0) closeToolWindow() - } + fun updateUI() /** * Retrieves the list of test case panels. * * @return The list of test case panels. */ - fun getTestCasePanels() = testCasePanels + fun getTestCasePanels(): HashMap /** * Retrieves the currently selected tests. * * @return The list of tests currently selected. */ - fun getTestsSelected() = testsSelected + fun getTestsSelected(): Int /** * Sets the number of tests selected. * * @param testsSelected The number of tests selected. */ - fun setTestsSelected(testsSelected: Int) { - this.testsSelected = testsSelected - } + fun setTestsSelected(testsSelected: Int) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt new file mode 100644 index 000000000..0dbc5009c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt @@ -0,0 +1,544 @@ +package org.jetbrains.research.testspark.services.java + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.fileChooser.FileChooser +import com.intellij.openapi.fileChooser.FileChooserDescriptor +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.fileEditor.FileEditorManager +import com.intellij.openapi.fileEditor.OpenFileDescriptor +import com.intellij.openapi.fileEditor.TextEditor +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.VirtualFileManager +import com.intellij.openapi.wm.ToolWindowManager +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiElementFactory +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiManager +import com.intellij.refactoring.suggested.startOffset +import com.intellij.ui.EditorTextField +import com.intellij.ui.JBColor +import com.intellij.ui.components.JBScrollPane +import com.intellij.ui.content.Content +import com.intellij.ui.content.ContentFactory +import com.intellij.ui.content.ContentManager +import com.intellij.util.containers.stream +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle +import org.jetbrains.research.testspark.core.data.Report +import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.data.UIContext +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.display.TopButtonsPanelFactory +import org.jetbrains.research.testspark.helpers.ReportHelper +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator +import org.jetbrains.research.testspark.java.JavaPsiClassWrapper +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.CoverageVisualisationService +import org.jetbrains.research.testspark.services.EditorService +import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.services.TestCaseDisplayService +import java.awt.BorderLayout +import java.awt.Color +import java.awt.Dimension +import java.io.File +import java.util.Locale +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JOptionPane +import javax.swing.JPanel +import javax.swing.JSeparator +import javax.swing.SwingConstants + +@Service(Service.Level.PROJECT) +class JavaTestCaseDisplayService(private val project: Project) : TestCaseDisplayService { + private var report: Report? = null + + private val unselectedTestCases = HashMap() + + private var mainPanel: JPanel = JPanel() + + private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Java) + + private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) + + private var allTestCasePanel: JPanel = JPanel() + + private var scrollPane: JBScrollPane = JBScrollPane( + allTestCasePanel, + JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, + ) + + private var testCasePanels: HashMap = HashMap() + + private var testsSelected: Int = 0 + + /** + * Default color for the editors in the tool window + */ + private var defaultEditorColor: Color? = null + + /** + * Content Manager to be able to add / remove tabs from tool window + */ + private var contentManager: ContentManager? = null + + /** + * Variable to keep reference to the coverage visualisation content + */ + private var content: Content? = null + + var uiContext: UIContext? = null + + init { + allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) + mainPanel.layout = BorderLayout() + + mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) + mainPanel.add(scrollPane, BorderLayout.CENTER) + + applyButton.isOpaque = false + applyButton.isContentAreaFilled = false + mainPanel.add(applyButton, BorderLayout.SOUTH) + + applyButton.addActionListener { applyTests() } + } + + override fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) { + this.report = report + this.uiContext = uiContext + + val editor = project.service().editor!! + + allTestCasePanel.removeAll() + testCasePanels.clear() + + addSeparator() + + // TestCasePanelFactories array + val testCasePanelFactories = arrayListOf() + + report.testCaseList.values.forEach { + val testCase = it + val testCasePanel = JPanel() + testCasePanel.layout = BorderLayout() + + // Add a checkbox to select the test + val checkbox = JCheckBox() + checkbox.isSelected = true + checkbox.addItemListener { + // Update the number of selected tests + testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) + + if (checkbox.isSelected) { + ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) + } else { + ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) + } + + updateUI() + } + testCasePanel.add(checkbox, BorderLayout.WEST) + + val testCasePanelFactory = + TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) + testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) + testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) + testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) + + testCasePanelFactories.add(testCasePanelFactory) + + testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) + + // Add panel to parent panel + testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) + allTestCasePanel.add(testCasePanel) + addSeparator() + testCasePanels[testCase.testName] = testCasePanel + } + + // Update the number of selected tests (all tests are selected by default) + testsSelected = testCasePanels.size + + topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) + topButtonsPanelFactory.updateTopLabels() + + createToolWindowTab() + } + + override fun addSeparator() { + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + } + + override fun highlightTestCase(name: String) { + val myPanel = testCasePanels[name] ?: return + openToolWindowTab() + scrollToPanel(myPanel) + + val editor = getEditor(name) ?: return + val settingsProjectState = project.service().state + val highlightColor = + JBColor( + PluginSettingsBundle.get("colorName"), + Color( + settingsProjectState.colorRed, + settingsProjectState.colorGreen, + settingsProjectState.colorBlue, + 30, + ), + ) + if (editor.background.equals(highlightColor)) return + defaultEditorColor = editor.background + editor.background = highlightColor + returnOriginalEditorBackground(editor) + } + + override fun openToolWindowTab() { + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + toolWindowManager.show() + toolWindowManager.contentManager.setSelectedContent(content!!) + } + } + + override fun scrollToPanel(myPanel: JPanel) { + var sum = 0 + for (component in allTestCasePanel.components) { + if (component == myPanel) { + break + } else { + sum += component.height + } + } + val scroll = scrollPane.verticalScrollBar + scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height + } + + override fun removeAllHighlights() { + project.service().editor?.markupModel?.removeAllHighlighters() + } + + override fun returnOriginalEditorBackground(editor: EditorTextField) { + Thread { + Thread.sleep(10000) + editor.background = defaultEditorColor + }.start() + } + + override fun highlightCoveredMutants(names: List) { + names.forEach { + highlightTestCase(it) + } + } + + override fun applyTests() { + // Filter the selected test cases + val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } + val selectedTestCases = selectedTestCasePanels.map { it.key } + + // Get the test case components (source code of the tests) + val testCaseComponents = selectedTestCases + .map { getEditor(it)!! } + .map { it.document.text } + + // Descriptor for choosing folders and java files + val descriptor = FileChooserDescriptor(true, true, false, false, false, false) + + // Apply filter with folders and java files with main class + WriteCommandAction.runWriteCommandAction(project) { + descriptor.withFileFilter { file -> + file.isDirectory || ( + file.extension?.lowercase(Locale.getDefault()) == "java" && ( + PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile + ).classes.stream().map { it.name } + .toArray() + .contains( + ( + PsiManager.getInstance(project) + .findFile(file) as PsiJavaFile + ).name.removeSuffix(".java"), + ) + ) + } + } + + val fileChooser = FileChooser.chooseFiles( + descriptor, + project, + LocalFileSystem.getInstance().findFileByPath(project.basePath!!), + ) + + /** + * Cancel button pressed + */ + if (fileChooser.isEmpty()) return + + /** + * Chosen files by user + */ + val chosenFile = fileChooser[0] + + /** + * Virtual file of a final java file + */ + var virtualFile: VirtualFile? = null + + /** + * PsiClass of a final java file + */ + var psiClass: PsiClass? = null + + /** + * PsiJavaFile of a final java file + */ + var psiJavaFile: PsiJavaFile? = null + + if (chosenFile.isDirectory) { + // Input new file data + var className: String + var fileName: String + var filePath: String + // Waiting for correct file name input + while (true) { + val jOptionPane = + JOptionPane.showInputDialog( + null, + PluginLabelsBundle.get("optionPaneMessage"), + PluginLabelsBundle.get("optionPaneTitle"), + JOptionPane.PLAIN_MESSAGE, + null, + null, + null, + ) + + // Cancel button pressed + jOptionPane ?: return + + // Get class name from user + className = jOptionPane as String + + // Set file name and file path + fileName = "${className.split('.')[0]}.java" + filePath = "${chosenFile.path}/$fileName" + + // Check the correctness of a class name + if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { + showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) + continue + } + + // Check the existence of a file with this name + if (File(filePath).exists()) { + showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) + continue + } + break + } + + // Create new file and set services of this file + WriteCommandAction.runWriteCommandAction(project) { + chosenFile.createChildData(null, fileName) + virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! + psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) + psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + + if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { + psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") + } + + psiJavaFile!!.add(psiClass!!) + } + } else { + // Set services of the chosen file + virtualFile = chosenFile + psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) + psiClass = psiJavaFile!!.classes[ + psiJavaFile!!.classes.stream().map { it.name }.toArray() + .indexOf(psiJavaFile!!.name.removeSuffix(".java")), + ] + } + + // Add tests to the file + WriteCommandAction.runWriteCommandAction(project) { + appendTestsToClass(testCaseComponents, JavaPsiClassWrapper(psiClass!!), psiJavaFile!!) + } + + // Remove the selected test cases from the cache and the tool window UI + removeSelectedTestCases(selectedTestCasePanels) + + // Open the file after adding + FileEditorManager.getInstance(project).openTextEditor( + OpenFileDescriptor(project, virtualFile!!), + true, + ) + } + + override fun showErrorWindow(message: String) { + JOptionPane.showMessageDialog( + null, + message, + PluginLabelsBundle.get("errorWindowTitle"), + JOptionPane.ERROR_MESSAGE, + ) + } + + override fun getEditor(testCaseName: String): EditorTextField? { + val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null + val middlePanel = middlePanelComponent as JPanel + return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField + } + + override fun appendTestsToClass( + testCaseComponents: List, + selectedClass: PsiClassWrapper, + outputFile: PsiFile, + ) { + // block document + PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( + PsiDocumentManager.getInstance(project).getDocument(outputFile as PsiJavaFile)!!, + ) + + // insert tests to a code + testCaseComponents.reversed().forEach { + val testMethodCode = + JavaTestClassCodeAnalyzer.extractFirstTestMethodCode( + JavaTestClassCodeGenerator.formatCode( + project, + it.replace("\r\n", "\n") + .replace("verifyException(", "// verifyException("), + uiContext!!.testGenerationOutput, + ), + ) + // Fix Windows line separators + .replace("\r\n", "\n") + + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + testMethodCode, + ) + } + + // insert other info to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + uiContext!!.testGenerationOutput.otherInfo + "\n", + ) + + // insert imports to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, + uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", + ) + + // insert package to a code + outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! + .insertString( + 0, + if (uiContext!!.testGenerationOutput.packageName.isEmpty()) { + "" + } else { + "package ${uiContext!!.testGenerationOutput.packageName};\n\n" + }, + ) + } + + override fun updateEditorForFileUrl(fileUrl: String) { + val documentManager = FileDocumentManager.getInstance() + // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 + FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { + val currentFile = documentManager.getFile(it.document) + if (currentFile != null) { + if (currentFile.presentableUrl == fileUrl) { + project.service().editor = it + } + } + } + } + + override fun createToolWindowTab() { + // Remove generated tests tab from content manager if necessary + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + contentManager!!.removeContent(content!!, true) + } + + // If there is no generated tests tab, make it + val contentFactory: ContentFactory = ContentFactory.getInstance() + content = contentFactory.createContent( + mainPanel, + PluginLabelsBundle.get("generatedTests"), + true, + ) + contentManager!!.addContent(content!!) + + // Focus on generated tests tab and open toolWindow if not opened already + contentManager!!.setSelectedContent(content!!) + toolWindowManager.show() + } + + override fun closeToolWindow() { + contentManager?.removeContent(content!!, true) + ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() + val coverageVisualisationService = project.service() + coverageVisualisationService.closeToolWindowTab() + } + + override fun removeSelectedTestCases(selectedTestCasePanels: Map) { + selectedTestCasePanels.forEach { removeTestCase(it.key) } + removeAllHighlights() + closeToolWindow() + } + + override fun clear() { + // Remove the tests + val testCasePanelsToRemove = testCasePanels.toMap() + removeSelectedTestCases(testCasePanelsToRemove) + + topButtonsPanelFactory.clear() + } + + override fun removeTestCase(testCaseName: String) { + // Update the number of selected test cases if necessary + if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { + testsSelected-- + } + + // Remove the test panel from the UI + allTestCasePanel.remove(testCasePanels[testCaseName]) + + // Remove the test panel + testCasePanels.remove(testCaseName) + } + + override fun updateUI() { + // Update the UI of the tool window tab + allTestCasePanel.updateUI() + + topButtonsPanelFactory.updateTopLabels() + + // If no more tests are remaining, close the tool window + if (testCasePanels.size == 0) closeToolWindow() + } + + override fun getTestCasePanels() = testCasePanels + + override fun getTestsSelected() = testsSelected + + override fun setTestsSelected(testsSelected: Int) { + this.testsSelected = testsSelected + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt new file mode 100644 index 000000000..a80952747 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt @@ -0,0 +1,553 @@ +package org.jetbrains.research.testspark.services.kotlin + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.fileChooser.FileChooser +import com.intellij.openapi.fileChooser.FileChooserDescriptor +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.fileEditor.FileEditorManager +import com.intellij.openapi.fileEditor.OpenFileDescriptor +import com.intellij.openapi.fileEditor.TextEditor +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.VirtualFileManager +import com.intellij.openapi.wm.ToolWindowManager +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiManager +import com.intellij.refactoring.suggested.endOffset +import com.intellij.refactoring.suggested.startOffset +import com.intellij.ui.EditorTextField +import com.intellij.ui.JBColor +import com.intellij.ui.components.JBScrollPane +import com.intellij.ui.content.Content +import com.intellij.ui.content.ContentFactory +import com.intellij.ui.content.ContentManager +import com.intellij.util.containers.stream +import org.jetbrains.kotlin.psi.KtClass +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtPsiFactory +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle +import org.jetbrains.research.testspark.core.data.Report +import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.data.UIContext +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.display.TopButtonsPanelFactory +import org.jetbrains.research.testspark.helpers.ReportHelper +import org.jetbrains.research.testspark.helpers.kotlin.KotlinClassBuilderHelper +import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.CoverageVisualisationService +import org.jetbrains.research.testspark.services.EditorService +import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.services.TestCaseDisplayService +import java.awt.BorderLayout +import java.awt.Color +import java.awt.Dimension +import java.io.File +import java.util.Locale +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JOptionPane +import javax.swing.JPanel +import javax.swing.JSeparator +import javax.swing.SwingConstants + +@Service(Service.Level.PROJECT) +class KotlinTestCaseDisplayService(private val project: Project) : TestCaseDisplayService { + private var report: Report? = null + + private val unselectedTestCases = HashMap() + + private var mainPanel: JPanel = JPanel() + + private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Kotlin) + + private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) + + private var allTestCasePanel: JPanel = JPanel() + + private var scrollPane: JBScrollPane = JBScrollPane( + allTestCasePanel, + JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, + ) + + private var testCasePanels: HashMap = HashMap() + + private var testsSelected: Int = 0 + + /** + * Default color for the editors in the tool window + */ + private var defaultEditorColor: Color? = null + + /** + * Content Manager to be able to add / remove tabs from tool window + */ + private var contentManager: ContentManager? = null + + /** + * Variable to keep reference to the coverage visualisation content + */ + private var content: Content? = null + + var uiContext: UIContext? = null + + init { + allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) + mainPanel.layout = BorderLayout() + + mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) + mainPanel.add(scrollPane, BorderLayout.CENTER) + + applyButton.isOpaque = false + applyButton.isContentAreaFilled = false + mainPanel.add(applyButton, BorderLayout.SOUTH) + + applyButton.addActionListener { applyTests() } + } + + override fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) { + this.report = report + this.uiContext = uiContext + + val editor = project.service().editor!! + + allTestCasePanel.removeAll() + testCasePanels.clear() + + addSeparator() + + // TestCasePanelFactories array + val testCasePanelFactories = arrayListOf() + + report.testCaseList.values.forEach { + val testCase = it + val testCasePanel = JPanel() + testCasePanel.layout = BorderLayout() + + // Add a checkbox to select the test + val checkbox = JCheckBox() + checkbox.isSelected = true + checkbox.addItemListener { + // Update the number of selected tests + testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) + + if (checkbox.isSelected) { + ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) + } else { + ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) + } + + updateUI() + } + testCasePanel.add(checkbox, BorderLayout.WEST) + + val testCasePanelFactory = + TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) + testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) + testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) + testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) + + testCasePanelFactories.add(testCasePanelFactory) + + testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) + + // Add panel to parent panel + testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) + allTestCasePanel.add(testCasePanel) + addSeparator() + testCasePanels[testCase.testName] = testCasePanel + } + + // Update the number of selected tests (all tests are selected by default) + testsSelected = testCasePanels.size + + topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) + topButtonsPanelFactory.updateTopLabels() + + createToolWindowTab() + } + + override fun addSeparator() { + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + } + + override fun highlightTestCase(name: String) { + val myPanel = testCasePanels[name] ?: return + openToolWindowTab() + scrollToPanel(myPanel) + + val editor = getEditor(name) ?: return + val settingsProjectState = project.service().state + val highlightColor = + JBColor( + PluginSettingsBundle.get("colorName"), + Color( + settingsProjectState.colorRed, + settingsProjectState.colorGreen, + settingsProjectState.colorBlue, + 30, + ), + ) + if (editor.background.equals(highlightColor)) return + defaultEditorColor = editor.background + editor.background = highlightColor + returnOriginalEditorBackground(editor) + } + + override fun openToolWindowTab() { + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + toolWindowManager.show() + toolWindowManager.contentManager.setSelectedContent(content!!) + } + } + + override fun scrollToPanel(myPanel: JPanel) { + var sum = 0 + for (component in allTestCasePanel.components) { + if (component == myPanel) { + break + } else { + sum += component.height + } + } + val scroll = scrollPane.verticalScrollBar + scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height + } + + override fun removeAllHighlights() { + project.service().editor?.markupModel?.removeAllHighlighters() + } + + override fun returnOriginalEditorBackground(editor: EditorTextField) { + Thread { + Thread.sleep(10000) + editor.background = defaultEditorColor + }.start() + } + + override fun highlightCoveredMutants(names: List) { + names.forEach { + highlightTestCase(it) + } + } + + override fun applyTests() { + // Filter the selected test cases + val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } + val selectedTestCases = selectedTestCasePanels.map { it.key } + + // Get the test case components (source code of the tests) + val testCaseComponents = selectedTestCases + .map { getEditor(it)!! } + .map { it.document.text } + + // Descriptor for choosing folders and java files + val descriptor = FileChooserDescriptor(true, true, false, false, false, false) + + // Apply filter with folders and java files with main class + WriteCommandAction.runWriteCommandAction(project) { + descriptor.withFileFilter { file -> + file.isDirectory || ( + file.extension?.lowercase(Locale.getDefault()) == "kotlin" && ( + PsiManager.getInstance(project).findFile(file!!) as KtFile + ).classes.stream().map { it.name } + .toArray() + .contains( + ( + PsiManager.getInstance(project) + .findFile(file) as PsiJavaFile + ).name.removeSuffix(".kt"), + ) + ) + } + } + + val fileChooser = FileChooser.chooseFiles( + descriptor, + project, + LocalFileSystem.getInstance().findFileByPath(project.basePath!!), + ) + + /** + * Cancel button pressed + */ + if (fileChooser.isEmpty()) return + + /** + * Chosen files by user + */ + val chosenFile = fileChooser[0] + + /** + * Virtual file of a final java file + */ + var virtualFile: VirtualFile? = null + + /** + * PsiClass of a final java file + */ + var ktClass: KtClass? = null + + /** + * PsiJavaFile of a final java file + */ + var psiKotlinFile: KtFile? = null + + if (chosenFile.isDirectory) { + // Input new file data + var className: String + var fileName: String + var filePath: String + // Waiting for correct file name input + while (true) { + val jOptionPane = + JOptionPane.showInputDialog( + null, + PluginLabelsBundle.get("optionPaneMessage"), + PluginLabelsBundle.get("optionPaneTitle"), + JOptionPane.PLAIN_MESSAGE, + null, + null, + null, + ) + + // Cancel button pressed + jOptionPane ?: return + + // Get class name from user + className = jOptionPane as String + + // Set file name and file path + fileName = "${className.split('.')[0]}.kt" + filePath = "${chosenFile.path}/$fileName" + + // Check the correctness of a class name + if (!Regex("[A-Z][a-zA-Z0-9]*(.kt)?").matches(className)) { + showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) + continue + } + + // Check the existence of a file with this name + if (File(filePath).exists()) { + showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) + continue + } + break + } + + // Create new file and set services of this file + WriteCommandAction.runWriteCommandAction(project) { + chosenFile.createChildData(null, fileName) + virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! + psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile) + + val ktPsiFactory = KtPsiFactory(project) + ktClass = ktPsiFactory.createClass("class ${className.split(".")[0]} {}") + + if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { + val annotationEntry = + ktPsiFactory.createAnnotationEntry("@RunWith(${uiContext!!.testGenerationOutput.runWith})") + ktClass!!.addBefore(annotationEntry, ktClass!!.body) + } + + psiKotlinFile!!.add(ktClass!!) + } + } else { + // Set services of the chosen file + virtualFile = chosenFile + psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile) + val classNameNoSuffix = psiKotlinFile!!.name.removeSuffix(".kt") + ktClass = psiKotlinFile?.declarations?.filterIsInstance()?.find { it.name == classNameNoSuffix } + } + + // Add tests to the file + WriteCommandAction.runWriteCommandAction(project) { + appendTestsToClass(testCaseComponents, KotlinPsiClassWrapper(ktClass as KtClass), psiKotlinFile!!) + } + + // Remove the selected test cases from the cache and the tool window UI + removeSelectedTestCases(selectedTestCasePanels) + + // Open the file after adding + FileEditorManager.getInstance(project).openTextEditor( + OpenFileDescriptor(project, virtualFile!!), + true, + ) + } + + override fun showErrorWindow(message: String) { + JOptionPane.showMessageDialog( + null, + message, + PluginLabelsBundle.get("errorWindowTitle"), + JOptionPane.ERROR_MESSAGE, + ) + } + + override fun getEditor(testCaseName: String): EditorTextField? { + val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null + val middlePanel = middlePanelComponent as JPanel + return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField + } + + override fun appendTestsToClass( + testCaseComponents: List, + selectedClass: PsiClassWrapper, + outputFile: PsiFile, + ) { + // block document + PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( + PsiDocumentManager.getInstance(project).getDocument(outputFile as KtFile)!!, + ) + + // insert tests to a code + testCaseComponents.reversed().forEach { + val testMethodCode = + KotlinClassBuilderHelper.extractFirstTestMethodCode( + KotlinClassBuilderHelper.formatCode( + project, + it.replace("\r\n", "\n") + .replace("verifyException(", "// verifyException("), + uiContext!!.testGenerationOutput, + ), + ) + // Fix Windows line separators + .replace("\r\n", "\n") + + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + testMethodCode, + ) + } + + // insert other info to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + uiContext!!.testGenerationOutput.otherInfo + "\n", + ) + + // Create the imports string + val importsString = uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n" + + // Find the insertion offset + val insertionOffset = outputFile.importList?.startOffset + ?: outputFile.packageDirective?.endOffset + ?: 0 + + // Insert the imports into the document + PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document -> + document.insertString(insertionOffset, importsString) + PsiDocumentManager.getInstance(project).commitDocument(document) + } + + val packageName = uiContext!!.testGenerationOutput.packageName + val packageStatement = if (packageName.isEmpty()) "" else "package $packageName\n\n" + + // Insert the package statement at the beginning of the document + PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document -> + document.insertString(0, packageStatement) + PsiDocumentManager.getInstance(project).commitDocument(document) + } + } + + override fun updateEditorForFileUrl(fileUrl: String) { + val documentManager = FileDocumentManager.getInstance() + // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 + FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { + val currentFile = documentManager.getFile(it.document) + if (currentFile != null) { + if (currentFile.presentableUrl == fileUrl) { + project.service().editor = it + } + } + } + } + + override fun createToolWindowTab() { + // Remove generated tests tab from content manager if necessary + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + contentManager!!.removeContent(content!!, true) + } + + // If there is no generated tests tab, make it + val contentFactory: ContentFactory = ContentFactory.getInstance() + content = contentFactory.createContent( + mainPanel, + PluginLabelsBundle.get("generatedTests"), + true, + ) + contentManager!!.addContent(content!!) + + // Focus on generated tests tab and open toolWindow if not opened already + contentManager!!.setSelectedContent(content!!) + toolWindowManager.show() + } + + override fun closeToolWindow() { + contentManager?.removeContent(content!!, true) + ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() + val coverageVisualisationService = project.service() + coverageVisualisationService.closeToolWindowTab() + } + + override fun removeSelectedTestCases(selectedTestCasePanels: Map) { + selectedTestCasePanels.forEach { removeTestCase(it.key) } + removeAllHighlights() + closeToolWindow() + } + + override fun clear() { + // Remove the tests + val testCasePanelsToRemove = testCasePanels.toMap() + removeSelectedTestCases(testCasePanelsToRemove) + + topButtonsPanelFactory.clear() + } + + override fun removeTestCase(testCaseName: String) { + // Update the number of selected test cases if necessary + if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { + testsSelected-- + } + + // Remove the test panel from the UI + allTestCasePanel.remove(testCasePanels[testCaseName]) + + // Remove the test panel + testCasePanels.remove(testCaseName) + } + + override fun updateUI() { + // Update the UI of the tool window tab + allTestCasePanel.updateUI() + + topButtonsPanelFactory.updateTopLabels() + + // If no more tests are remaining, close the tool window + if (testCasePanels.size == 0) closeToolWindow() + } + + override fun getTestCasePanels() = testCasePanels + + override fun getTestsSelected() = testsSelected + + override fun setTestsSelected(testsSelected: Int) { + this.testsSelected = testsSelected + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt index 89e480e83..6c3d77a05 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt @@ -45,7 +45,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { // Models private var modelSelector = ComboBox(arrayOf("")) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) // Default LLM Requests private var defaultLLMRequestsSeparator = diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt index 5f792b328..2b0ff5769 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt @@ -42,6 +42,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab settingsComponent!!.llmPlatforms[index].token = llmSettingsState.grazieToken settingsComponent!!.llmPlatforms[index].model = llmSettingsState.grazieModel } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + settingsComponent!!.llmPlatforms[index].token = llmSettingsState.huggingFaceToken + settingsComponent!!.llmPlatforms[index].model = llmSettingsState.huggingFaceModel + } } settingsComponent!!.currentLLMPlatformName = llmSettingsState.currentLLMPlatformName settingsComponent!!.maxLLMRequest = llmSettingsState.maxLLMRequest @@ -81,6 +85,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.grazieToken) modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.grazieModel) } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.huggingFaceToken) + modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.huggingFaceModel) + } } modified = modified or (settingsComponent!!.currentLLMPlatformName != llmSettingsState.currentLLMPlatformName) modified = modified or (settingsComponent!!.maxLLMRequest != llmSettingsState.maxLLMRequest) @@ -138,6 +146,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab llmSettingsState.grazieToken = settingsComponent!!.llmPlatforms[index].token llmSettingsState.grazieModel = settingsComponent!!.llmPlatforms[index].model } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = settingsComponent!!.llmPlatforms[index].token + llmSettingsState.huggingFaceModel = settingsComponent!!.llmPlatforms[index].model + } } llmSettingsState.currentLLMPlatformName = settingsComponent!!.currentLLMPlatformName llmSettingsState.maxLLMRequest = settingsComponent!!.maxLLMRequest diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt index 3ce378707..590ec3c1d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt @@ -15,6 +15,9 @@ data class LLMSettingsState( var grazieName: String = DefaultLLMSettingsState.grazieName, var grazieToken: String = DefaultLLMSettingsState.grazieToken, var grazieModel: String = DefaultLLMSettingsState.grazieModel, + var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName, + var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken, + var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel, var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName, var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest, var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth, @@ -45,6 +48,9 @@ data class LLMSettingsState( val grazieName: String = LLMDefaultsBundle.get("grazieName") val grazieToken: String = LLMDefaultsBundle.get("grazieToken") val grazieModel: String = LLMDefaultsBundle.get("grazieModel") + val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName") + val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken") + val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel") var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName") val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt() val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt index 0cd1b073a..c4310ba61 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt @@ -2,7 +2,7 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.application.PathManager import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.dependencies.JavaTestCompilationDependencies +import org.jetbrains.research.testspark.core.test.data.dependencies.TestCompilationDependencies import java.io.File /** @@ -16,7 +16,7 @@ class LibraryPathsProvider { private val sep = File.separatorChar private val libPrefix = "${PathManager.getPluginsPath()}${sep}TestSpark${sep}lib$sep" - fun getTestCompilationLibraryPaths() = JavaTestCompilationDependencies.getJarDescriptors().map { descriptor -> + fun getTestCompilationLibraryPaths() = TestCompilationDependencies.getJarDescriptors().map { descriptor -> "$libPrefix${sep}${descriptor.name}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index aa5b694b7..30ed0ba6b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.roots.ProjectRootManager import com.intellij.openapi.util.io.FileUtilRt import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext @@ -22,6 +22,8 @@ import org.jetbrains.research.testspark.services.CoverageVisualisationService import org.jetbrains.research.testspark.services.EditorService import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.util.UUID @@ -29,7 +31,7 @@ import java.util.UUID * Pipeline class represents a pipeline for generating tests in a project. * * @param project the project in which the pipeline is executed. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caretOffset the offset of the caret position in the PSI file. * @param fileUrl the URL of the file being processed, if applicable. * @param packageName the package name of the file being processed. @@ -47,7 +49,7 @@ class Pipeline( init { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! + val cutPsiClass = psiHelper.getSurroundingClass(caretOffset) // get generated test path val testResultDirectory = "${FileUtilRt.getTempDirectory()}${ToolUtils.sep}testSparkResults${ToolUtils.sep}" @@ -57,10 +59,8 @@ class Pipeline( ApplicationManager.getApplication().runWriteAction { projectContext.projectClassPath = ProjectRootManager.getInstance(project).contentRoots.first().path projectContext.fileUrlAsString = fileUrl - projectContext.classFQN = cutPsiClass.qualifiedName - // TODO probably can be made easier - projectContext.cutModule = - ProjectFileIndex.getInstance(project).getModuleForFile(cutPsiClass.virtualFile)!! + cutPsiClass?.let { projectContext.classFQN = it.qualifiedName } + projectContext.cutModule = psiHelper.getModuleFromPsiFile() } generatedTestsData.resultPath = ToolUtils.getResultPath(id, testResultDirectory) @@ -108,14 +108,13 @@ class Pipeline( override fun onFinished() { super.onFinished() testGenerationController.finished() - uiContext?.let { - project.service() - .updateEditorForFileUrl(it.testGenerationOutput.fileUrl) - - if (project.service().editor != null) { - val report = it.testGenerationOutput.testGenerationResultList[0]!! - project.service().displayTestCases(report, it, psiHelper.language) - project.service().showCoverage(report) + when (psiHelper.language) { + SupportedLanguage.Java -> uiContext?.let { + displayTestCase(it) + } + + SupportedLanguage.Kotlin -> uiContext?.let { + displayTestCase(it) } } } @@ -124,8 +123,22 @@ class Pipeline( private fun clear(project: Project) { // should be removed totally! testGenerationController.errorMonitor.clear() - project.service().clear() + when (psiHelper.language) { + SupportedLanguage.Java -> project.service().clear() + SupportedLanguage.Kotlin -> project.service().clear() + } + project.service().clear() project.service().clear() } + + private inline fun displayTestCase(ctx: UIContext) { + project.service().updateEditorForFileUrl(ctx.testGenerationOutput.fileUrl) + + if (project.service().editor != null) { + val report = ctx.testGenerationOutput.testGenerationResultList[0]!! + project.service().displayTestCases(report, ctx, psiHelper.language) + project.service().showCoverage(report) + } + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt new file mode 100644 index 000000000..ea0c0bc2e --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt @@ -0,0 +1,17 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.java.JavaTestBodyPrinter +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter + +class TestBodyPrinterFactory { + companion object { + fun create(language: SupportedLanguage): TestBodyPrinter { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestBodyPrinter() + SupportedLanguage.Java -> JavaTestBodyPrinter() + } + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt new file mode 100644 index 000000000..1b73c380c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer + +object TestClassCodeAnalyzerFactory { + /** + * Creates an instance of TestClassCodeAnalyzer for the specified language. + * + * @param language the programming language for which to create the analyzer + * @return an instance of TestClassCodeAnalyzer + */ + fun create(language: SupportedLanguage): TestClassCodeAnalyzer { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeAnalyzer + SupportedLanguage.Java -> JavaTestClassCodeAnalyzer + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt new file mode 100644 index 000000000..56151e26e --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator + +object TestClassCodeGeneratorFactory { + /** + * Creates an instance of TestClassCodeGenerator for the specified language. + * + * @param language the programming language for which to create the generator + * @return an instance of TestClassCodeGenerator + */ + fun create(language: SupportedLanguage): TestClassCodeGenerator { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeGenerator + SupportedLanguage.Java -> JavaTestClassCodeGenerator + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt index 8680370bd..84b512bb5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt @@ -3,20 +3,31 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.java.JavaTestCompiler +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestCompiler class TestCompilerFactory { companion object { - fun createJavacTestCompiler( + fun create( project: Project, junitVersion: JUnitVersion, + language: SupportedLanguage, javaHomeDirectory: String? = null, ): TestCompiler { - val javaHomePath = javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + val javaSDKHomePath = + javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk?.homeDirectory?.path + ?: throw RuntimeException("Java SDK not configured for the project.") + val libraryPaths = LibraryPathsProvider.getTestCompilationLibraryPaths() val junitLibraryPaths = LibraryPathsProvider.getJUnitLibraryPaths(junitVersion) - return TestCompiler(javaHomePath, libraryPaths, junitLibraryPaths) + // TODO add the warning window that for Java we always need the javaHomeDirectoryPath + return when (language) { + SupportedLanguage.Java -> JavaTestCompiler(libraryPaths, junitLibraryPaths, javaSDKHomePath) + SupportedLanguage.Kotlin -> KotlinTestCompiler(libraryPaths, junitLibraryPaths) + } } } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt index e0a4150b4..d35589357 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt @@ -8,6 +8,7 @@ import com.intellij.openapi.roots.CompilerModuleExtension import com.intellij.openapi.roots.ModuleRootManager import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil @@ -25,16 +26,20 @@ class TestProcessor( val project: Project, givenProjectSDKPath: Path? = null, ) : TestsPersistentStorage { - private val javaHomeDirectory = givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + private val homeDirectory = + givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path private val log = Logger.getInstance(this::class.java) private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state - val testCompiler = TestCompilerFactory.createJavacTestCompiler(project, llmSettingsState.junitVersion, javaHomeDirectory) - - override fun saveGeneratedTest(packageString: String, code: String, resultPath: String, testFileName: String): String { + override fun saveGeneratedTest( + packageString: String, + code: String, + resultPath: String, + testFileName: String, + ): String { // Generate the final path for the generated tests var generatedTestPath = "$resultPath${File.separatorChar}" packageString.split(".").forEach { directory -> @@ -69,14 +74,10 @@ class TestProcessor( generatedTestPackage: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): String { // find the proper javac - val javaRunner = File(javaHomeDirectory).walk() - .filter { - val isJavaName = if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") - isJavaName && it.isFile - } - .first() + val javaRunner = findJavaCompilerInDirectory(homeDirectory) // JaCoCo libs val jacocoAgentLibraryPath = "\"${LibraryPathsProvider.getJacocoAgentLibraryPath()}\"" val jacocoCLILibraryPath = "\"${LibraryPathsProvider.getJacocoCliLibraryPath()}\"" @@ -90,13 +91,21 @@ class TestProcessor( val junitVersion = llmSettingsState.junitVersion.version // run the test method with jacoco agent + log.info("[TestProcessor] Executing $name") val junitRunnerLibraryPath = LibraryPathsProvider.getJUnitRunnerLibraryPath() + // classFQN will be null for the top level function + val javaAgentFlag = + if (projectContext.classFQN != null) { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}" + } else { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false" + } val testExecutionError = CommandLineRunner.run( arrayListOf( javaRunner.absolutePath, - "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}", + javaAgentFlag, "-cp", - "\"${testCompiler.getPath(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", + "\"${testCompiler.getClassPaths(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", "org.jetbrains.research.SingleJUnitTestRunner$junitVersion", name, ), @@ -148,9 +157,10 @@ class TestProcessor( testId: Int, testName: String, testCode: String, - packageLine: String, + packageName: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): TestCase { // get buildPath var buildPath: String = ProjectRootManager.getInstance(project).contentRoots.first().path @@ -161,7 +171,7 @@ class TestProcessor( // save new test to file val generatedTestPath: String = saveGeneratedTest( - packageLine, + packageName, testCode, resultPath, fileName, @@ -179,9 +189,10 @@ class TestProcessor( dataFileName, testName, buildPath, - packageLine, + packageName, resultPath, projectContext, + testCompiler, ) if (!File("$dataFileName.xml").exists()) { @@ -230,7 +241,8 @@ class TestProcessor( frames.removeFirst() frames.forEach { frame -> - if (frame.contains(projectContext.classFQN!!)) { + // classFQN will be null for the top level function + if (projectContext.classFQN != null && frame.contains(projectContext.classFQN!!)) { val coveredLineNumber = frame.split(":")[1].replace(")", "").toIntOrNull() if (coveredLineNumber != null) { result.add(coveredLineNumber) @@ -274,7 +286,8 @@ class TestProcessor( children("counter") {} } children("sourcefile") { - isCorrectSourceFile = this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() + isCorrectSourceFile = + this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() children("line") { if (isCorrectSourceFile && this.attributes.getValue("mi") == "0") { setOfLines.add(this.attributes.getValue("nr").toInt()) @@ -295,4 +308,18 @@ class TestProcessor( return TestCase(testCaseId, testCaseName, testCaseCode, setOfLines) } + + /** + * Finds 'javac' compiler (both on Unix & Windows) + * starting from the provided directory. + */ + private fun findJavaCompilerInDirectory(homeDirectory: String): File { + return File(homeDirectory).walk() + .filter { + val isJavaName = + if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") + isJavaName && it.isFile + } + .first() + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt new file mode 100644 index 000000000..3c4ca5637 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt @@ -0,0 +1,31 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.java.JavaJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser + +class TestSuiteParserFactory { + companion object { + fun createJUnitTestSuiteParser( + jUnitVersion: JUnitVersion, + language: SupportedLanguage, + testBodyPrinter: TestBodyPrinter, + packageName: String = "", + ): TestSuiteParser = when (language) { + SupportedLanguage.Java -> JavaJUnitTestSuiteParser( + packageName, + jUnitVersion, + testBodyPrinter, + ) + + SupportedLanguage.Kotlin -> KotlinJUnitTestSuiteParser( + packageName, + jUnitVersion, + testBodyPrinter, + ) + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt new file mode 100644 index 000000000..a896d273c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt @@ -0,0 +1,18 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler + +class TestsAssemblerFactory { + companion object { + fun create( + indicator: CustomProgressIndicator, + generationData: TestGenerationData, + testSuiteParser: TestSuiteParser, + junitVersion: JUnitVersion, + ) = JUnitTestsAssembler(indicator, generationData, testSuiteParser, junitVersion) + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3ba26b9c5..a7ef25eb2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -11,9 +11,9 @@ import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.IJTestCase -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.services.TestsExecutionResultService import java.io.File @@ -21,68 +21,37 @@ object ToolUtils { val sep = File.separatorChar val pathSep = File.pathSeparatorChar - /** - * Retrieves the imports code from a given test suite code. - * - * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. - * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. - * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. - */ - fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String): MutableSet { - testSuiteCode ?: return mutableSetOf() - return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() - .filter { it.contains("^import".toRegex()) } - .filterNot { it.contains("evosuite".toRegex()) } - .filterNot { it.contains("RunWith".toRegex()) } - .filterNot { it.contains(classFQN.toRegex()) }.toMutableSet() - } - - /** - * Retrieves the package declaration from the given test suite code. - * - * @param testSuiteCode The generated code of the test suite. - * @return The package declaration extracted from the test suite code, or an empty string if no package declaration was found. - */ -// get package from a generated code - fun getPackageFromTestSuiteCode(testSuiteCode: String?): String { - testSuiteCode ?: return "" - if (!testSuiteCode.contains("package")) return "" - val result = testSuiteCode.replace("\r\n", "\n").split("\n") - .filter { it.contains("^package".toRegex()) }.joinToString("").split("package ")[1].split(";")[0] - if (result.isBlank()) return "" - return result - } - /** * Saves the data related to test generation in the specified project's workspace. * * @param project The project in which the test generation data will be saved. * @param report The report object to be added to the test generation result list. - * @param packageLine The package declaration line of the test generation data. + * @param packageName The package declaration line of the test generation data. * @param importsCode The import statements code of the test generation data. */ fun saveData( project: Project, report: Report, - packageLine: String, + packageName: String, importsCode: MutableSet, fileUrl: String, generatedTestData: TestGenerationData, + language: SupportedLanguage = SupportedLanguage.Java, ) { generatedTestData.fileUrl = fileUrl - generatedTestData.packageLine = packageLine + generatedTestData.packageName = packageName generatedTestData.importsCode.addAll(importsCode) project.service().initExecutionResult(report.testCaseList.values.map { it.id }) for (testCase in report.testCaseList.values) { val code = testCase.testCode - testCase.testCode = JavaClassBuilderHelper.generateCode( + testCase.testCode = TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCase.testName), code, generatedTestData.importsCode, - generatedTestData.packageLine, + generatedTestData.packageName, generatedTestData.runWith, generatedTestData.otherInfo, generatedTestData, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 46b982ac1..4e4c75a75 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -5,7 +5,7 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.actions.controllers.TestGenerationController -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt index 20b47872c..18a593ca2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuiteSettingsArguments.kt @@ -175,4 +175,4 @@ class EvoSuiteSettingsArguments( return if (command == "-Dcriterion=") "-Dcriterion=LINE" else command } } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt index c1e5e6560..8c180f9df 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt @@ -15,10 +15,13 @@ import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteDefaultsBundle import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.core.utils.CommandLineRunner -import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext @@ -200,8 +203,8 @@ class EvoSuiteProcessManager( ToolUtils.saveData( project, IJReport(testGenerationResult), - ToolUtils.getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode), - ToolUtils.getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), + getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode, SupportedLanguage.Java), + getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), projectContext.fileUrlAsString!!, generatedTestsData, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 01f16176c..89a27df64 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -1,11 +1,12 @@ package org.jetbrains.research.testspark.tools.llm import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -23,6 +24,8 @@ import java.nio.file.Path */ class Llm(override val name: String = "LLM") : Tool { + private val log = Logger.getInstance(this::class.java) + /** * Returns an instance of the LLMProcessManager. * @@ -74,6 +77,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for CLASS was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -107,6 +111,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for METHOD was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -141,6 +146,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for LINE was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -174,9 +180,7 @@ class Llm(override val name: String = "LLM") : Tool { fileUrl: String?, testGenerationController: TestGenerationController, ): Pipeline { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! - val packageList = cutPsiClass.qualifiedName.split(".").dropLast(1) - val packageName = packageList.joinToString(".") + val packageName = psiHelper.getPackageName() return Pipeline(project, psiHelper, caretOffset, fileUrl, packageName, testGenerationController) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 921980ed6..271cf4b49 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -57,6 +57,7 @@ class LlmSettingsArguments(private val project: Project) { fun getToken(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIToken llmSettingsState.grazieName -> llmSettingsState.grazieToken + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken else -> "" } @@ -68,6 +69,7 @@ class LlmSettingsArguments(private val project: Project) { fun getModel(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIModel llmSettingsState.grazieName -> llmSettingsState.grazieModel + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel else -> "" } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index e1bcb67ec..1196016b2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -1,36 +1,27 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestSuiteParser import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.java.JavaJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.kotlin.KotlinJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.utils.Language -import org.jetbrains.research.testspark.core.utils.javaImportPattern -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState /** * Assembler class for generating and organizing test cases. * - * @property project The project to which the tests belong. * @property indicator The progress indicator to display the progress of test generation. * @property log The logger for logging debug information. * @property lastTestCount The count of the last generated tests. */ class JUnitTestsAssembler( - val project: Project, val indicator: CustomProgressIndicator, - val generationData: TestGenerationData, + private val generationData: TestGenerationData, + private val testSuiteParser: TestSuiteParser, + val junitVersion: JUnitVersion, ) : TestsAssembler() { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state private val log: Logger = Logger.getInstance(this.javaClass) @@ -58,11 +49,8 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? { - val junitVersion = llmSettingsState.junitVersion - - val parser = createTestSuiteParser(packageName, junitVersion, language) - val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(super.getContent()) + override fun assembleTestSuite(): TestSuiteGeneratedByLLM? { + val testSuite = testSuiteParser.parseTestSuite(super.getContent()) // save RunWith if (testSuite?.runWith?.isNotBlank() == true) { @@ -80,15 +68,4 @@ class JUnitTestsAssembler( testSuite?.testCases?.forEach { testCase -> log.info("Generated test case: $testCase") } return testSuite } - - private fun createTestSuiteParser( - packageName: String, - jUnitVersion: JUnitVersion, - language: Language, - ): TestSuiteParser { - return when (language) { - Language.Java -> JavaJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - Language.Kotlin -> KotlinJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - } - } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index bb1dee0ff..f46dd5603 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -3,25 +3,32 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.FeedbackCycleExecutionResult import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager @@ -34,7 +41,6 @@ import java.nio.file.Path * and is responsible for generating tests using the LLM tool. * * @property project The project in which the test generation is being performed. - * @property prompt The prompt to be sent to the LLM tool. * @property testFileName The name of the generated test file. * @property log An instance of the logger class for logging purposes. * @property llmErrorManager An instance of the LLMErrorManager class. @@ -42,19 +48,23 @@ import java.nio.file.Path */ class LLMProcessManager( private val project: Project, - private val language: Language, + private val language: SupportedLanguage, private val promptManager: PromptManager, private val testSamplesCode: String, - projectSDKPath: Path? = null, + private val projectSDKPath: Path? = null, ) : ProcessManager { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state - private val testFileName: String = "GeneratedTest.java" + private val homeDirectory = + projectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + + private val testFileName: String = when (language) { + SupportedLanguage.Java -> "GeneratedTest.java" + SupportedLanguage.Kotlin -> "GeneratedTest.kt" + } private val log = Logger.getInstance(this::class.java) private val llmErrorManager: LLMErrorManager = LLMErrorManager() private val maxRequests = LlmSettingsArguments(project).maxLLMRequest() - private val testProcessor = TestProcessor(project, projectSDKPath) + private val testProcessor: TestsPersistentStorage = TestProcessor(project, projectSDKPath) /** * Runs the test generator process. @@ -91,16 +101,16 @@ class LLMProcessManager( val report = IJReport() // PROMPT GENERATION - val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) - - val testCompiler = testProcessor.testCompiler + val initialPromptMessage = + promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) // initiate a new RequestManager val requestManager = StandardRequestManagerFactory(project).getRequestManager(project) // adapter for the existing prompt reduction functionality val promptSizeReductionStrategy = object : PromptSizeReductionStrategy { - override fun isReductionPossible(): Boolean = promptManager.isPromptSizeReductionPossible(generatedTestsData) + override fun isReductionPossible(): Boolean = + promptManager.isPromptSizeReductionPossible(generatedTestsData) override fun reduceSizeAndGeneratePrompt(): String { if (!isReductionPossible()) { @@ -115,7 +125,7 @@ class LLMProcessManager( // adapter for the existing test case/test suite string representing functionality val testsPresenter = object : TestsPresenter { - private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) override fun representTestSuite(testSuite: TestSuiteGeneratedByLLM): String { return testSuitePresenter.toStringWithoutExpectedException(testSuite) @@ -126,6 +136,29 @@ class LLMProcessManager( } } + // Creation of JUnit specific parser, printer and assembler + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + packageName, + ) + val testsAssembler = TestsAssemblerFactory.create( + indicator, + generatedTestsData, + testSuiteParser, + jUnitVersion, + ) + + val testCompiler = TestCompilerFactory.create( + project, + jUnitVersion, + language, + homeDirectory, + ) + // Asking LLM to generate a test suite. Here we have a feedback cycle for LLM in case of wrong responses val llmFeedbackCycle = LLMWithFeedbackCycle( language = language, @@ -137,7 +170,7 @@ class LLMProcessManager( resultPath = generatedTestsData.resultPath, buildPath = buildPath, requestManager = requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, generatedTestsData), + testsAssembler = testsAssembler, testCompiler = testCompiler, testStorage = testProcessor, testsPresenter = testsPresenter, @@ -150,8 +183,10 @@ class LLMProcessManager( when (warning) { LLMWithFeedbackCycle.WarningType.TEST_SUITE_PARSING_FAILED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.NO_TEST_CASES_GENERATED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.COMPILATION_ERROR_OCCURRED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("compilationError"), project) } @@ -167,17 +202,21 @@ class LLMProcessManager( // store compilable test cases generatedTestsData.compilableTestCases.addAll(feedbackResponse.compilableTestCases) } + FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("invalidLLMResult"), project, errorMonitor) } + FeedbackCycleExecutionResult.CANCELED -> { log.info("Process stopped") return null } + FeedbackCycleExecutionResult.PROVIDED_PROMPT_TOO_LONG -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("tooLongPromptRequest"), project, errorMonitor) return null } + FeedbackCycleExecutionResult.SAVING_TEST_FILES_ISSUE -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("savingTestFileIssue"), project, errorMonitor) } @@ -190,7 +229,7 @@ class LLMProcessManager( log.info("Save generated test suite and test cases into the project workspace") - val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) val generatedTestSuite: TestSuiteGeneratedByLLM? = feedbackResponse.generatedTestSuite val testSuiteRepresentation = if (generatedTestSuite != null) testSuitePresenter.toString(generatedTestSuite) else null @@ -200,10 +239,11 @@ class LLMProcessManager( ToolUtils.saveData( project, report, - ToolUtils.getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation), - ToolUtils.getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN!!), + getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation, language), + getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN), projectContext.fileUrlAsString!!, generatedTestsData, + language, ) return UIContext(projectContext, generatedTestsData, requestManager, errorMonitor) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index d7ac8f9f5..08e5be765 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -5,7 +5,6 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.util.Computable import com.intellij.openapi.util.TextRange -import com.intellij.psi.PsiDocumentManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle import org.jetbrains.research.testspark.core.data.TestGenerationData @@ -15,7 +14,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptConfiguration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptGenerationContext import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptTemplates -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -31,7 +30,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager * A class that manages prompts for generating unit tests. * * @constructor Creates a PromptManager with the given parameters. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caret The place of the caret. */ class PromptManager( @@ -39,6 +38,9 @@ class PromptManager( private val psiHelper: PsiHelper, private val caret: Int, ) { + /** + * The `classesToTest` is empty when we work with the function outside the class + */ private val classesToTest: List get() { val classesToTest = mutableListOf() @@ -52,7 +54,10 @@ class PromptManager( return classesToTest } - private val cut: PsiClassWrapper = classesToTest[0] + /** + * The `cut` is null when we work with the function outside the class. + */ + private val cut: PsiClassWrapper? = if (classesToTest.isNotEmpty()) classesToTest[0] else null private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state @@ -79,7 +84,7 @@ class PromptManager( .toMap() val context = PromptGenerationContext( - cut = createClassRepresentation(cut), + cut = cut?.let { createClassRepresentation(it) }, classesToTest = classesToTest.map(this::createClassRepresentation).toList(), polymorphismRelations = polymorphismRelations, promptConfiguration = PromptConfiguration( @@ -110,7 +115,12 @@ class PromptManager( .map(this::createClassRepresentation) .toList() - promptGenerator.generatePromptForMethod(method, interestingClassesFromMethod, testSamplesCode) + promptGenerator.generatePromptForMethod( + method, + interestingClassesFromMethod, + testSamplesCode, + psiHelper.getPackageName(), + ) } CodeType.LINE -> { @@ -118,7 +128,7 @@ class PromptManager( val psiMethod = getPsiMethod(cut, getMethodDescriptor(cut, lineNumber))!! // get code of line under test - val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) + val document = psiHelper.getDocumentFromPsiFile() val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) val lineEndOffset = document.getLineEndOffset(lineNumber - 1) @@ -149,7 +159,7 @@ class PromptManager( signature = psiMethod.signature, name = psiMethod.name, text = psiMethod.text!!, - containingClassQualifiedName = psiMethod.containingClass!!.qualifiedName, + containingClassQualifiedName = psiMethod.containingClass?.qualifiedName ?: "", ) } @@ -210,7 +220,6 @@ class PromptManager( * * @param project The project context in which the PsiClasses exist. * @param interestingPsiClasses The set of PsiClassWrappers that are considered interesting. - * @param cutPsiClass The cut PsiClassWrapper to determine polymorphism relations against. * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. */ private fun getPolymorphismRelationsWithQualifiedNames( @@ -219,6 +228,9 @@ class PromptManager( ): MutableMap> { val polymorphismRelations: MutableMap> = mutableMapOf() + // assert(interestingPsiClasses.isEmpty()) + if (cut == null) return polymorphismRelations + interestingPsiClasses.add(cut) interestingPsiClasses.forEach { currentInterestingClass -> @@ -245,9 +257,14 @@ class PromptManager( * @return The matching PsiMethod if found, otherwise an empty string. */ private fun getPsiMethod( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, methodDescriptor: String, ): PsiMethodWrapper? { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return currentPsiMethod + } for (currentPsiMethod in psiClass.allMethods) { val file = psiClass.containingFile val psiHelper = PsiHelperProvider.getPsiHelper(file) @@ -268,9 +285,14 @@ class PromptManager( * @return the method descriptor as a String, or an empty string if no method is found */ private fun getMethodDescriptor( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, lineNumber: Int, ): String { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return psiHelper.generateMethodDescriptor(currentPsiMethod) + } for (currentPsiMethod in psiClass.allMethods) { if (currentPsiMethod.containsLine(lineNumber)) { val file = psiClass.containingFile diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt index 46daefc30..f05d55986 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt @@ -6,6 +6,7 @@ import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager interface RequestManagerFactory { @@ -20,6 +21,7 @@ class StandardRequestManagerFactory(private val project: Project) : RequestManag return when (val platform = LlmSettingsArguments(project).currentLLMPlatformName()) { llmSettingsState.openAIName -> OpenAIRequestManager(project) llmSettingsState.grazieName -> GrazieRequestManager(project) + llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project) else -> throw IllegalStateException("Unknown selected platform: $platform") } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt index c2267beb8..45581b8cf 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt @@ -62,14 +62,12 @@ class GrazieRequestManager(project: Project) : IJRequestManager(project) { } private fun getMessages(): List> { - val result = mutableListOf>() - chatHistory.forEach { + return chatHistory.map { val role = when (it.role) { ChatMessage.ChatRole.User -> "user" ChatMessage.ChatRole.Assistant -> "assistant" } - result.add(Pair(role, it.content)) + (role to it.content) } - return result } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt new file mode 100644 index 000000000..e5b93f588 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt @@ -0,0 +1,9 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform + +class HuggingFacePlatform( + override val name: String = "HuggingFace", + override var token: String = "", + override var model: String = "", +) : LLMPlatform diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt new file mode 100644 index 000000000..6ef09950f --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt @@ -0,0 +1,33 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import org.jetbrains.research.testspark.core.data.ChatMessage + +data class Parameters( + val topProbability: Double, + val temperature: Double, +) + +data class HuggingFaceRequestBody( + val messages: List, + val parameters: Parameters, +) + +/** + * Sets LLM settings required to send inference requests to HF + * For more info, see https://huggingface.co/docs/api-inference/en/detailed_parameters + */ +fun HuggingFaceRequestBody.toMap(): Map { + return mapOf( + "inputs" to this.messages.joinToString(separator = "\n") { it.content }, + // TODO: These parameters can be set by the user in the plugin's settings too. + "parameters" to mapOf( + "top_p" to this.parameters.topProbability, + "temperature" to this.parameters.temperature, + "min_length" to 4096, + "max_length" to 8192, + "max_new_tokens" to 250, + "max_time" to 120.0, + "return_full_text" to false, + ), + ) +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt new file mode 100644 index 000000000..e99a25bf2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt @@ -0,0 +1,116 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import com.google.gson.GsonBuilder +import com.google.gson.JsonParser +import com.intellij.openapi.project.Project +import com.intellij.util.io.HttpRequests +import com.intellij.util.io.HttpRequests.HttpStatusException +import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle +import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.data.ChatUserMessage +import org.jetbrains.research.testspark.core.monitor.ErrorMonitor +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestsAssembler +import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager +import org.jetbrains.research.testspark.tools.llm.generation.IJRequestManager +import java.net.HttpURLConnection + +/** + * A class to manage requests sent to large language models hosted on HuggingFace + */ +class HuggingFaceRequestManager(project: Project) : IJRequestManager(project) { + private val url = "https://api-inference.huggingface.co/models/meta-llama/" + + // TODO: The user should be able to change these numbers in the plugin's settings + private val topProbability = 0.9 + private val temperature = 0.9 + + private val llmErrorManager = LLMErrorManager() + + override fun send( + prompt: String, + indicator: CustomProgressIndicator, + testsAssembler: TestsAssembler, + errorMonitor: ErrorMonitor, + ): SendResult { + val httpRequest = HttpRequests.post( + url + LlmSettingsArguments(project).getModel(), + "application/json", + ).tuner { + it.setRequestProperty("Authorization", "Bearer $token") + } + + // Add system prompt + if (chatHistory.size == 1) { + chatHistory[0] = ChatUserMessage( + createInstructionPrompt( + chatHistory[0].content, + ), + ) + } + + val llmRequestBody = HuggingFaceRequestBody(chatHistory, Parameters(topProbability, temperature)).toMap() + var sendResult = SendResult.OK + try { + httpRequest.connect { + it.write(GsonBuilder().disableHtmlEscaping().create().toJson(llmRequestBody)) + when (val responseCode = (it.connection as HttpURLConnection).responseCode) { + HttpURLConnection.HTTP_OK -> { + val text = it.reader.readLine() + val generatedTestCases = extractLLMGeneratedCode( + JsonParser.parseString(text).asJsonArray[0] + .asJsonObject["generated_text"].asString.trim(), + ) + testsAssembler.consume(generatedTestCases) + } + + HttpURLConnection.HTTP_INTERNAL_ERROR -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("serverProblems"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + + HttpURLConnection.HTTP_BAD_REQUEST -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("hfServerError"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + } + } + } catch (e: HttpStatusException) { + log.error { "Error in sending request: ${e.message}" } + } + return sendResult + } + + /** + * Creates the required prompt for Llama models. For more details see: + * https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + */ + private fun createInstructionPrompt(userMessage: String): String { + // TODO: This is Llama-specific and should support other LLMs hosted on HF too. + return "[INST] <> ${LLMDefaultsBundle.get("huggingFaceInitialSystemPrompt")} <> $userMessage [/INST]" + } + + /** + * Extracts code blocks in LLMs' response. + * Also, it handles the cases where the LLM-generated code does not end with ``` + */ + private fun extractLLMGeneratedCode(text: String): String { + // TODO: This method should support other languages other than Java. + val modifiedText = text.replace("```java", "```").replace("````", "```") + val tripleTickBlockIndex = modifiedText.indexOf("```") + val codePart = modifiedText.substring(tripleTickBlockIndex + 3) + val lines = codePart.lines() + val filteredLines = lines.filter { line -> line != "```" } + val code = filteredLines.joinToString("\n") + return "```\n$code\n```" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt index 40e0c3fba..33138c4f8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt @@ -1,9 +1,30 @@ package org.jetbrains.research.testspark.tools.llm.generation.openai -import org.jetbrains.research.testspark.core.data.ChatMessage +/** + * Adheres the naming of fields for OpenAI chat completion API and checks the correctness of a `role`. + *
+ * Use this class as a carrier of messages that should be sent to OpenAI API. + */ +data class OpenAIChatMessage(val role: String, val content: String) { + private companion object { + /** + * The API strictly defines the set of roles. + * The `function` role is omitted because it is already deprecated. + * + * See: https://platform.openai.com/docs/api-reference/chat/create + */ + val supportedRoles = listOf("user", "assistant", "system", "tool") + } + + init { + if (!supportedRoles.contains(role)) { + throw IllegalArgumentException("'$role' is not supported ${OpenAIChatMessage::class}. Available roles are: ${(supportedRoles.joinToString(", ") { "'$it'" })}") + } + } +} data class OpenAIRequestBody( val model: String, - val messages: List, + val messages: List, val stream: Boolean = true, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt index ed6607d3e..1d9d6a9a4 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt @@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.io.HttpRequests import com.intellij.util.io.HttpRequests.HttpStatusException import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.data.ChatMessage import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -35,22 +36,29 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { errorMonitor: ErrorMonitor, ): SendResult { // Prepare the chat - val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), chatHistory) + val messages = chatHistory.map { + val role = when (it.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "assistant" + } + OpenAIChatMessage(role, it.content) + } + + val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), messages) var sendResult = SendResult.OK try { - httpRequest.connect { - it.write(GsonBuilder().create().toJson(llmRequestBody)) + httpRequest.connect { request -> + // send request to OpenAI API + request.write(GsonBuilder().create().toJson(llmRequestBody)) + + val connection = request.connection as HttpURLConnection // check response - when (val responseCode = (it.connection as HttpURLConnection).responseCode) { + when (val responseCode = connection.responseCode) { HttpURLConnection.HTTP_OK -> { - assembleLlmResponse( - httpRequest = it, - indicator, - testsAssembler, - ) + assembleLlmResponse(request, testsAssembler, indicator) } HttpURLConnection.HTTP_INTERNAL_ERROR -> { @@ -105,13 +113,12 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { */ private fun assembleLlmResponse( httpRequest: HttpRequests.Request, - indicator: CustomProgressIndicator, testsAssembler: TestsAssembler, + indicator: CustomProgressIndicator, ) { while (true) { if (ToolUtils.isProcessCanceled(indicator)) return - Thread.sleep(50L) var text = httpRequest.reader.readLine() if (text.isEmpty()) continue diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index b1473b0c9..10aded741 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -3,12 +3,14 @@ package org.jetbrains.research.testspark.tools.llm.test import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper +import org.jetbrains.research.testspark.tools.TestClassCodeGeneratorFactory class JUnitTestSuitePresenter( private val project: Project, private val generatedTestsData: TestGenerationData, + private val language: SupportedLanguage, ) { /** * Returns a string representation of this object. @@ -34,12 +36,12 @@ class JUnitTestSuitePresenter( // Add each test testCases.forEach { testCase -> testBody += "$testCase\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -57,12 +59,12 @@ class JUnitTestSuitePresenter( testCaseIndex: Int, ): String = testSuite.run { - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCases[testCaseIndex].name), testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -81,12 +83,12 @@ class JUnitTestSuitePresenter( // Add each test (exclude expected exception) testCases.forEach { testCase -> testBody += "${testCase.toStringWithoutExpectedException()}\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -105,8 +107,8 @@ class JUnitTestSuitePresenter( fun getPrintablePackageString(testSuite: TestSuiteGeneratedByLLM): String { return testSuite.run { when { - packageString.isEmpty() || packageString.isBlank() -> "" - else -> packageString + packageName.isEmpty() || packageName.isBlank() -> "" + else -> packageName } } } diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 75aa7c65b..dffd2b46d 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -1,6 +1,6 @@ - org.jetbrains.research.testspark + org.jetbrains.research.testgenie TestSpark ictl diff --git a/src/main/resources/properties/llm/LLMDefaults.properties b/src/main/resources/properties/llm/LLMDefaults.properties index 156f15cbd..1eddae6e2 100644 --- a/src/main/resources/properties/llm/LLMDefaults.properties +++ b/src/main/resources/properties/llm/LLMDefaults.properties @@ -4,6 +4,10 @@ openAIModel= grazieName=AI Assistant JetBrains grazieToken= grazieModel= +huggingFaceName=HuggingFace +huggingFaceToken= +huggingFaceModel= +huggingFaceInitialSystemPrompt=You are a helpful and honest code and programming assistant. Please, respond concisely and truthfully. maxLLMRequest=3 maxInputParamsDepth=2 maxPolyDepth=2 diff --git a/src/main/resources/properties/llm/LLMMessages.properties b/src/main/resources/properties/llm/LLMMessages.properties index db087d5c1..3502840ab 100644 --- a/src/main/resources/properties/llm/LLMMessages.properties +++ b/src/main/resources/properties/llm/LLMMessages.properties @@ -14,4 +14,5 @@ grazieError=Grazie test generation feature is not available in this build. removeTemplateMessage=Choose another default template to remove this one. removeTemplateTitle=Can't Be Removed defaultPromptIsNotValidMessage=Default prompt is not valid. Fix it, please. -defaultPromptIsNotValidTitle=Incorrect Prompt State \ No newline at end of file +defaultPromptIsNotValidTitle=Incorrect Prompt State +hfServerError=The selected model may need an HF PRO subscription to use! \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt index 934a4fac5..a05013d13 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt @@ -215,4 +215,4 @@ class SettingsArgumentsLlmEvoSuiteTest { criterion, ).isEqualTo("-Dcriterion=LINE:BRANCH:EXCEPTION:WEAKMUTATION:OUTPUT:METHOD:METHODNOEXCEPTION:CBRANCH") } -} \ No newline at end of file +} From 3610571f361979da30a72f23595d60b49d49810d Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 20:47:50 +0200 Subject: [PATCH 07/19] renaming of getSurroundingLine --- .../org/jetbrains/research/testspark/java/JavaPsiHelper.kt | 4 ++-- .../jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt | 4 ++-- .../research/testspark/langwrappers/PsiComponents.kt | 2 +- .../jetbrains/research/testspark/tools/evosuite/EvoSuite.kt | 2 +- .../kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index d2b8dac35..f6f132a29 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -67,7 +67,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -158,7 +158,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper? val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper? - val line: Int? = getSurroundingLine(caret.offset) + val line: Int? = getSurroundingLineNumber(caret.offset) javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index ca131f7da..333132a14 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -70,7 +70,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -165,7 +165,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) - val line: Int? = getSurroundingLine(caret.offset)?.plus(1) + val line: Int? = getSurroundingLineNumber(caret.offset)?.plus(1) ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index c6f98afeb..0aa5dfd0f 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -114,7 +114,7 @@ interface PsiHelper { * @param caretOffset The caret offset within the PSI file. * @return The line number of the selected line, otherwise null. */ - fun getSurroundingLine(caretOffset: Int): Int? + fun getSurroundingLineNumber(caretOffset: Int): Int? /** * Retrieves a set of interesting PsiClasses based on a given project, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 4e4c75a75..529bb4b8e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -88,7 +88,7 @@ class EvoSuite(override val name: String = "EvoSuite") : Tool { */ override fun generateTestsForLine(project: Project, psiHelper: PsiHelper, caretOffset: Int, fileUrl: String?, testSamplesCode: String, testGenerationController: TestGenerationController) { log.info("Starting tests generation for line by EvoSuite") - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! createPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( getEvoSuiteProcessManager(project), FragmentToTestData( diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 89a27df64..980707a2a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -151,7 +151,7 @@ class Llm(override val name: String = "LLM") : Tool { testGenerationController.finished() return } - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! val codeType = FragmentToTestData(CodeType.LINE, selectedLine) createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( From 4d5f39d4f8e210028198f04340ebd2523ca8e18e Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 20:50:21 +0200 Subject: [PATCH 08/19] fixing compilation error --- .../services/kotlin/KotlinTestCaseDisplayService.kt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt index a80952747..a77edd16d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt @@ -39,7 +39,8 @@ import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.display.TestCasePanelFactory import org.jetbrains.research.testspark.display.TopButtonsPanelFactory import org.jetbrains.research.testspark.helpers.ReportHelper -import org.jetbrains.research.testspark.helpers.kotlin.KotlinClassBuilderHelper +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.services.CoverageVisualisationService @@ -417,8 +418,8 @@ class KotlinTestCaseDisplayService(private val project: Project) : TestCaseDispl // insert tests to a code testCaseComponents.reversed().forEach { val testMethodCode = - KotlinClassBuilderHelper.extractFirstTestMethodCode( - KotlinClassBuilderHelper.formatCode( + KotlinTestClassCodeAnalyzer.extractFirstTestMethodCode( + KotlinTestClassCodeGenerator.formatCode( project, it.replace("\r\n", "\n") .replace("verifyException(", "// verifyException("), From e54c0449ff7dc99e7fdb5dc48234150040c02a07 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:02:03 +0200 Subject: [PATCH 09/19] fix: problem one with wrong cast --- build.gradle.kts | 1 + .../testspark/core/data/TestGenerationData.kt | 4 +- .../generation/llm/LLMWithFeedbackCycle.kt | 14 +- .../testspark/core/generation/llm/Utils.kt | 48 +- .../generation/llm/network/RequestManager.kt | 12 +- .../generation/llm/prompt/PromptBuilder.kt | 13 +- .../generation/llm/prompt/PromptGenerator.kt | 6 +- .../llm/prompt/configuration/Configuration.kt | 5 +- .../testspark/core/test/SupportedLanguage.kt | 8 + .../testspark/core/test/TestBodyPrinter.kt | 21 + .../testspark/core/test/TestCompiler.kt | 60 +- .../test/{parsers => }/TestSuiteParser.kt | 4 +- .../testspark/core/test/TestsAssembler.kt | 8 +- .../core/test/TestsPersistentStorage.kt | 1 + .../testspark/core/test}/data/CodeType.kt | 2 +- .../core/test/data/TestCaseGeneratedByLLM.kt | 29 +- .../core/test/data/TestSuiteGeneratedByLLM.kt | 4 +- ...cies.kt => TestCompilationDependencies.kt} | 2 +- .../test/java/JavaJUnitTestSuiteParser.kt | 32 + .../core/test/java/JavaTestBodyPrinter.kt | 40 ++ .../core/test/java/JavaTestCompiler.kt | 53 ++ .../test/kotlin/KotlinJUnitTestSuiteParser.kt | 32 + .../core/test/kotlin/KotlinTestBodyPrinter.kt | 40 ++ .../core/test/kotlin/KotlinTestCompiler.kt | 31 + .../parsers/java/JavaJUnitTestSuiteParser.kt | 22 - .../kotlin/KotlinJUnitTestSuiteParser.kt | 22 - .../JUnitTestSuiteParserStrategy.kt | 173 ------ .../JUnitTestSuiteParserStrategy.kt | 175 ++++++ .../research/testspark/core/utils/Language.kt | 8 - .../research/testspark/core/utils/Patterns.kt | 10 +- .../kotlin/KotlinJUnitTestSuiteParserTest.kt | 161 ++++- .../testspark/java/JavaPsiClassWrapper.kt | 32 +- .../research/testspark/java/JavaPsiHelper.kt | 57 +- .../testspark/kotlin/KotlinPsiClassWrapper.kt | 40 +- .../testspark/kotlin/KotlinPsiHelper.kt | 90 ++- .../kotlin/KotlinPsiMethodWrapper.kt | 20 + langwrappers/build.gradle.kts | 2 - .../LanguageClassTextExtractor.kt | 7 + .../testspark/langwrappers/PsiComponents.kt | 42 +- .../JavaKotlinClassTextExtractor.kt | 39 ++ .../testspark/actions/TestSparkAction.kt | 87 +-- .../actions/llm/LLMSampleSelectorFactory.kt | 5 +- .../actions/llm/LLMSetupPanelFactory.kt | 6 +- .../actions/llm/TestSamplePanelFactory.kt | 4 +- .../testspark/appstarter/TestSparkStarter.kt | 14 +- .../testspark/data/FragmentToTestData.kt | 2 + .../testspark/display/TestCasePanelFactory.kt | 53 +- .../display/TopButtonsPanelFactory.kt | 70 +-- .../strategies/TopButtonsPanelStrategy.kt | 138 +++++ .../testspark/helpers/CoverageHelper.kt | 6 +- .../helpers/JavaClassBuilderHelper.kt | 204 ------- .../research/testspark/helpers/LLMHelper.kt | 55 +- .../helpers/TestClassCodeAnalyzer.kt | 39 ++ .../helpers/TestClassCodeGenerator.kt | 43 ++ .../helpers/java/JavaTestClassCodeAnalyzer.kt | 78 +++ .../java/JavaTestClassCodeGenerator.kt | 104 ++++ .../kotlin/KotlinTestClassCodeAnalyzer.kt | 65 ++ .../kotlin/KotlinTestClassCodeGenerator.kt | 101 ++++ .../CoverageToolWindowDisplayService.kt | 0 .../services/TestCaseDisplayService.kt | 527 +---------------- .../java/JavaTestCaseDisplayService.kt | 544 +++++++++++++++++ .../kotlin/KotlinTestCaseDisplayService.kt | 554 ++++++++++++++++++ .../settings/llm/LLMSettingsComponent.kt | 2 +- .../settings/llm/LLMSettingsConfigurable.kt | 12 + .../settings/llm/LLMSettingsState.kt | 6 + .../testspark/tools/LibraryPathsProvider.kt | 4 +- .../research/testspark/tools/Pipeline.kt | 45 +- .../testspark/tools/TestBodyPrinterFactory.kt | 17 + .../tools/TestClassCodeAnalyzerFactory.kt | 21 + .../tools/TestClassCodeGeneratorFactory.kt | 21 + .../testspark/tools/TestCompilerFactory.kt | 17 +- .../research/testspark/tools/TestProcessor.kt | 61 +- .../testspark/tools/TestSuiteParserFactory.kt | 31 + .../testspark/tools/TestsAssemblerFactory.kt | 18 + .../research/testspark/tools/ToolUtils.kt | 45 +- .../testspark/tools/evosuite/EvoSuite.kt | 4 +- .../generation/EvoSuiteProcessManager.kt | 9 +- .../research/testspark/tools/llm/Llm.kt | 14 +- .../tools/llm/LlmSettingsArguments.kt | 2 + .../llm/generation/JUnitTestsAssembler.kt | 35 +- .../tools/llm/generation/LLMProcessManager.kt | 76 ++- .../tools/llm/generation/PromptManager.kt | 44 +- .../llm/generation/RequestManagerFactory.kt | 2 + .../generation/grazie/GrazieRequestManager.kt | 6 +- .../llm/generation/hf/HuggingFacePlatform.kt | 9 + .../generation/hf/HuggingFaceRequestBody.kt | 33 ++ .../hf/HuggingFaceRequestManager.kt | 116 ++++ .../generation/openai/OpenAIRequestBody.kt | 25 +- .../generation/openai/OpenAIRequestManager.kt | 29 +- .../tools/llm/test/JUnitTestSuitePresenter.kt | 20 +- .../properties/llm/LLMDefaults.properties | 4 + .../properties/llm/LLMMessages.properties | 3 +- 92 files changed, 3318 insertions(+), 1482 deletions(-) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt rename core/src/main/kotlin/org/jetbrains/research/testspark/core/test/{parsers => }/TestSuiteParser.kt (87%) rename {src/main/kotlin/org/jetbrains/research/testspark => core/src/main/kotlin/org/jetbrains/research/testspark/core/test}/data/CodeType.kt (73%) rename core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/{JavaTestCompilationDependencies.kt => TestCompilationDependencies.kt} (96%) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt create mode 100644 langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt create mode 100644 langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt diff --git a/build.gradle.kts b/build.gradle.kts index 13da233c4..5e6621e29 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -157,6 +157,7 @@ dependencies { // https://mvnrepository.com/artifact/org.mockito/mockito-all testImplementation("org.mockito:mockito-all:1.10.19") + testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0") // https://mvnrepository.com/artifact/net.jqwik/jqwik testImplementation("net.jqwik:jqwik:1.6.5") diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt index d11f346d5..a35212cb1 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt @@ -16,7 +16,7 @@ data class TestGenerationData( // Code required of imports and package for generated tests var importsCode: MutableSet = mutableSetOf(), - var packageLine: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", @@ -37,7 +37,7 @@ data class TestGenerationData( resultName = "" fileUrl = "" importsCode = mutableSetOf() - packageLine = "" + packageName = "" runWith = "" otherInfo = "" polyDepthReducing = 0 diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 0c8a428aa..973b26e7a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -10,13 +10,13 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeRed import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import java.io.File enum class FeedbackCycleExecutionResult { @@ -45,7 +45,7 @@ data class FeedbackResponse( class LLMWithFeedbackCycle( private val report: Report, - private val language: Language, + private val language: SupportedLanguage, private val initialPromptMessage: String, private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in result path @@ -167,13 +167,15 @@ class LLMWithFeedbackCycle( generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) } else { for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = - "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + val testCaseFilename = when (language) { + SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt" + } val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) val saveFilepath = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testCaseRepresentation, resultPath, testCaseFilename, @@ -184,7 +186,7 @@ class LLMWithFeedbackCycle( } val generatedTestSuitePath: String = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testsPresenter.representTestSuite(generatedTestSuite), resultPath, testSuiteFilename, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index 76cb74c17..1942a6a86 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -4,13 +4,47 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.utils.javaPackagePattern +import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import java.util.Locale // TODO: find a better place for the below functions +/** + * Retrieves the package declaration from the given test suite code for any language. + * + * @param testSuiteCode The generated code of the test suite. + * @return The package name extracted from the test suite code, or an empty string if no package declaration was found. + */ +fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String { + testSuiteCode ?: return "" + return when (language) { + SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + } +} + +/** + * Retrieves the imports code from a given test suite code. + * + * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. + * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. + * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. + */ +fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String?): MutableSet { + testSuiteCode ?: return mutableSetOf() + return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() + .filter { it.contains("^import".toRegex()) } + .filterNot { it.contains("evosuite".toRegex()) } + .filterNot { it.contains("RunWith".toRegex()) } + // classFQN will be null for the top level function + .filterNot { classFQN != null && it.contains(classFQN.toRegex()) } + .toMutableSet() +} + /** * Returns the generated class name for a given test case. * @@ -39,7 +73,7 @@ fun getClassWithTestCaseName(testCaseName: String): String { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun executeTestCaseModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -50,15 +84,7 @@ fun executeTestCaseModificationRequest( // Update Token information val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task" - var packageName = "" - testCase.split("\n")[0].let { - if (it.startsWith("package")) { - packageName = it - .removePrefix("package ") - .removeSuffix(";") - .trim() - } - } + val packageName = getPackageFromTestSuiteCode(testCase, language) val response = requestManager.request( language, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index 689eec798..441e51231 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -7,8 +7,8 @@ import org.jetbrains.research.testspark.core.data.ChatUserMessage import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.core.utils.Language abstract class RequestManager(var token: String) { enum class SendResult { @@ -31,7 +31,7 @@ abstract class RequestManager(var token: String) { * @return the generated TestSuite, or null and prompt message */ open fun request( - language: Language, + language: SupportedLanguage, prompt: String, indicator: CustomProgressIndicator, packageName: String, @@ -65,7 +65,7 @@ abstract class RequestManager(var token: String) { open fun processResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { // save the full response in the chat history val response = testsAssembler.getContent() @@ -78,7 +78,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) @@ -97,7 +97,7 @@ abstract class RequestManager(var token: String) { open fun processUserFeedbackResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { val response = testsAssembler.getContent() @@ -108,7 +108,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 278d58655..036e87a0d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -78,7 +78,7 @@ internal class PromptBuilder(private var prompt: String) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" } for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java")) { + if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { continue } @@ -88,7 +88,9 @@ internal class PromptBuilder(private var prompt: String) { // Skip java methods // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java")) { + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin") + ) { continue } @@ -106,8 +108,11 @@ internal class PromptBuilder(private var prompt: String) { ) = apply { val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { - var fullText = "" - + // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable + var fullText = when { + polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" + else -> "" + } polymorphismRelations.forEach { entry -> for (currentSubClass in entry.value) { val subClassTypeName = when (currentSubClass.classType) { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt index 3afbd3cff..72340867a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt @@ -19,7 +19,7 @@ class PromptGenerator( fun generatePromptForClass(interestingClasses: List, testSamplesCode: String): String { val prompt = PromptBuilder(promptTemplates.classPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(context.cut.qualifiedName) + .insertName(context.cut!!.qualifiedName) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(context.cut.fullText, context.classesToTest) @@ -44,10 +44,12 @@ class PromptGenerator( method: MethodRepresentation, interestingClassesFromMethod: List, testSamplesCode: String, + packageName: String, ): String { + val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" val prompt = PromptBuilder(promptTemplates.methodPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName("${context.cut.qualifiedName}.${method.name}") + .insertName(name) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(method.text, context.classesToTest) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt index 4094de1aa..6b87e8941 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt @@ -10,7 +10,10 @@ import org.jetbrains.research.testspark.core.data.ClassType * @property polymorphismRelations A map where the key represents a ClassRepresentation object and the value is a list of its detected subclasses. */ data class PromptGenerationContext( - val cut: ClassRepresentation, + /** + * The cut is null when we want to generate tests for top-level function + */ + val cut: ClassRepresentation?, val classesToTest: List, val polymorphismRelations: Map>, val promptConfiguration: PromptConfiguration, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt new file mode 100644 index 000000000..4b4de90c8 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt @@ -0,0 +1,8 @@ +package org.jetbrains.research.testspark.core.test + +/** + * Language ID string should be the same as the language name in com.intellij.lang.Language + */ +enum class SupportedLanguage(val languageId: String) { + Java("JAVA"), Kotlin("kotlin") +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt new file mode 100644 index 000000000..450400ac3 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestBodyPrinter.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.core.test + +import org.jetbrains.research.testspark.core.test.data.TestLine + +interface TestBodyPrinter { + /** + * Generates a test body as a string based on the provided parameters. + * + * @param testInitiatedText A string containing the upper part of the test case. + * @param lines A mutable list of `TestLine` objects representing the lines of the test body. + * @param throwsException The exception type that the test function throws, if any. + * @param name The name of the test function. + * @return A string representing the complete test body. + */ + fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt index bc4d40617..b49281aaf 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt @@ -1,32 +1,24 @@ package org.jetbrains.research.testspark.core.test -import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil -import java.io.File data class TestCasesCompilationResult( val allTestCasesCompilable: Boolean, val compilableTestCases: MutableSet, ) -/** - * TestCompiler is a class that is responsible for compiling generated test cases using the proper javac. - * It provides methods for compiling test cases and code files. - */ -open class TestCompiler( - private val javaHomeDirectoryPath: String, +abstract class TestCompiler( private val libPaths: List, private val junitLibPaths: List, ) { - private val log = KotlinLogging.logger { this::class.java } - /** - * Compiles the generated files with test cases using the proper javac. + * Compiles a list of test cases and returns the compilation result. * - * @return true if all the provided test cases are successfully compiled, - * otherwise returns false. + * @param generatedTestCasesPaths A list of file paths where the generated test cases are located. + * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. + * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled. + * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases. */ fun compileTestCases( generatedTestCasesPaths: List, @@ -51,45 +43,11 @@ open class TestCompiler( * Compiles the code at the specified path using the provided project build path. * * @param path The path of the code file to compile. - * @param projectBuildPath The project build path to use during compilation. + * @param projectBuildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. * @return A pair containing a boolean value indicating whether the compilation was successful (true) or not (false), * and a string message describing any error encountered during compilation. */ - fun compileCode(path: String, projectBuildPath: String): Pair { - // find the proper javac - val javaCompile = File(javaHomeDirectoryPath).walk() - .filter { - val isCompilerName = if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") - isCompilerName && it.isFile - } - .firstOrNull() - - if (javaCompile == null) { - val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" - log.error { msg } - throw RuntimeException(msg) - } - - println("javac found at '${javaCompile.absolutePath}'") - - // compile file - val errorMsg = CommandLineRunner.run( - arrayListOf( - javaCompile.absolutePath, - "-cp", - "\"${getPath(projectBuildPath)}\"", - path, - ), - ) - - log.info { "Error message: '$errorMsg'" } - - // create .class file path - val classFilePath = path.replace(".java", ".class") - - // check is .class file exists - return Pair(File(classFilePath).exists(), errorMsg) - } + abstract fun compileCode(path: String, projectBuildPath: String): Pair /** * Generates the path for the command by concatenating the necessary paths. @@ -97,7 +55,7 @@ open class TestCompiler( * @param buildPath The path of the build file. * @return The generated path as a string. */ - fun getPath(buildPath: String): String { + fun getClassPaths(buildPath: String): String { // create the path for the command val separator = DataFilesUtil.classpathSeparator val dependencyLibPath = libPaths.joinToString(separator.toString()) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt similarity index 87% rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt index a0551ed7c..60c4016d4 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestSuiteParser.kt @@ -1,4 +1,4 @@ -package org.jetbrains.research.testspark.core.test.parsers +package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM @@ -11,7 +11,7 @@ data class TestCaseParseResult( interface TestSuiteParser { /** - * Extracts test cases from raw text and generates a test suite using the given package name. + * Extracts test cases from raw text and generates a test suite. * * @param rawText The raw text provided by the LLM that contains the generated test cases. * @return A GeneratedTestSuite instance containing the extracted test cases. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index 6e5a4e127..0d9c672de 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -1,7 +1,6 @@ package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language abstract class TestsAssembler { private var rawText = "" @@ -33,10 +32,9 @@ abstract class TestsAssembler { } /** - * Extracts test cases from raw text and generates a TestSuite using the given package name. + * Extracts test cases from raw text and generates a TestSuite. * - * @param packageName The package name to be set in the generated TestSuite. - * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name. + * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases. */ - abstract fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM? } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt index 1673fea4a..b9d50132c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt @@ -4,6 +4,7 @@ package org.jetbrains.research.testspark.core.test * The TestPersistentStorage interface represents a contract for saving generated tests to a specified file system location. */ interface TestsPersistentStorage { + /** * Save the generated tests to a specified directory. * diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt similarity index 73% rename from src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt index 8e91aded4..12f18eb54 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/CodeType.kt @@ -1,4 +1,4 @@ -package org.jetbrains.research.testspark.data +package org.jetbrains.research.testspark.core.test.data /** * Enum class, which contains all code elements for which it is possible to request test generation. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt index 6ef9f6907..2a565e82e 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.core.test.data +import org.jetbrains.research.testspark.core.test.TestBodyPrinter + /** * * Represents a test case generated by LLM. @@ -11,6 +13,7 @@ data class TestCaseGeneratedByLLM( var expectedException: String = "", var throwsException: String = "", var lines: MutableList = mutableListOf(), + val printTestBodyStrategy: TestBodyPrinter, ) { /** @@ -104,31 +107,7 @@ data class TestCaseGeneratedByLLM( * @return a string containing the body of test case */ private fun printTestBody(testInitiatedText: String): String { - var testFullText = testInitiatedText - - // start writing the test signature - testFullText += "\n\tpublic void $name() " - - // add throws exception if exists - if (throwsException.isNotBlank()) { - testFullText += "throws $throwsException" - } - - // start writing the test lines - testFullText += "{\n" - - // write each line - lines.forEach { line -> - testFullText += when (line.type) { - TestLineType.BREAK -> "\t\t\n" - else -> "\t\t${line.text}\n" - } - } - - // close test case - testFullText += "\t}\n" - - return testFullText + return printTestBodyStrategy.printTestBody(testInitiatedText, lines, throwsException, name) } /** diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt index 211063bb7..4fac9b8b9 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt @@ -4,12 +4,12 @@ package org.jetbrains.research.testspark.core.test.data * Represents a test suite generated by LLM. * * @property imports The set of import statements in the test suite. - * @property packageString The package string of the test suite. + * @property packageName The package name of the test suite. * @property testCases The list of test cases in the test suite. */ data class TestSuiteGeneratedByLLM( var imports: Set = emptySet(), - var packageString: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", var testCases: MutableList = mutableListOf(), diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt similarity index 96% rename from core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt index 2e78b0b50..622ab0c98 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/TestCompilationDependencies.kt @@ -6,7 +6,7 @@ import org.jetbrains.research.testspark.core.data.JarLibraryDescriptor * The class represents a list of dependencies required for java test compilation. * The libraries listed are used during test suite/test case compilation. */ -class JavaTestCompilationDependencies { +class TestCompilationDependencies { companion object { fun getJarDescriptors() = listOf( JarLibraryDescriptor( diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt new file mode 100644 index 000000000..279badc57 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt @@ -0,0 +1,32 @@ +package org.jetbrains.research.testspark.core.test.java + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.javaImportPattern + +class JavaJUnitTestSuiteParser( + private var packageName: String, + private val junitVersion: JUnitVersion, + private val testBodyPrinter: TestBodyPrinter, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Java) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( + rawText, + junitVersion, + javaImportPattern, + packageName, + testNamePattern = "void", + testBodyPrinter, + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt new file mode 100644 index 000000000..bafbcaf13 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestBodyPrinter.kt @@ -0,0 +1,40 @@ +package org.jetbrains.research.testspark.core.test.java + +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType + +class JavaTestBodyPrinter : TestBodyPrinter { + override fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String { + var testFullText = testInitiatedText + + // start writing the test signature + testFullText += "\n\tpublic void $name() " + + // add throws exception if exists + if (throwsException.isNotBlank()) { + testFullText += "throws $throwsException" + } + + // start writing the test lines + testFullText += "{\n" + + // write each line + lines.forEach { line -> + testFullText += when (line.type) { + TestLineType.BREAK -> "\t\t\n" + else -> "\t\t${line.text}\n" + } + } + + // close test case + testFullText += "\t}\n" + + return testFullText + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt new file mode 100644 index 000000000..98f0a3d0c --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaTestCompiler.kt @@ -0,0 +1,53 @@ +package org.jetbrains.research.testspark.core.test.java + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.utils.CommandLineRunner +import org.jetbrains.research.testspark.core.utils.DataFilesUtil +import java.io.File + +class JavaTestCompiler( + libPaths: List, + junitLibPaths: List, + private val javaHomeDirectoryPath: String, +) : TestCompiler(libPaths, junitLibPaths) { + + private val log = KotlinLogging.logger { this::class.java } + + override fun compileCode(path: String, projectBuildPath: String): Pair { + val classPaths = "\"${getClassPaths(projectBuildPath)}\"" + // find the proper javac + val javaCompile = File(javaHomeDirectoryPath).walk() + .filter { + val isCompilerName = + if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") + isCompilerName && it.isFile + } + .firstOrNull() + + if (javaCompile == null) { + val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" + log.error { msg } + throw RuntimeException(msg) + } + + println("javac found at '${javaCompile.absolutePath}'") + + // compile file + val errorMsg = CommandLineRunner.run( + arrayListOf( + javaCompile.absolutePath, + "-cp", + classPaths, + path, + ), + ) + + log.info { "Error message: '$errorMsg'" } + // create .class file path + val classFilePath = path.replace(".java", ".class") + + // check is .class file exists + return Pair(File(classFilePath).exists(), errorMsg) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt new file mode 100644 index 000000000..18b164810 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt @@ -0,0 +1,32 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.test.strategies.JUnitTestSuiteParserStrategy +import org.jetbrains.research.testspark.core.utils.kotlinImportPattern + +class KotlinJUnitTestSuiteParser( + private var packageName: String, + private val junitVersion: JUnitVersion, + private val testBodyPrinter: TestBodyPrinter, +) : TestSuiteParser { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Kotlin) + if (packageInsideTestText.isNotBlank()) { + packageName = packageInsideTestText + } + + return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( + rawText, + junitVersion, + kotlinImportPattern, + packageName, + testNamePattern = "fun", + testBodyPrinter, + ) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt new file mode 100644 index 000000000..a1a9dc8df --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestBodyPrinter.kt @@ -0,0 +1,40 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType + +class KotlinTestBodyPrinter : TestBodyPrinter { + override fun printTestBody( + testInitiatedText: String, + lines: MutableList, + throwsException: String, + name: String, + ): String { + var testFullText = testInitiatedText + + // start writing the test signature + testFullText += "\n\tfun $name() " + + // add throws exception if exists + if (throwsException.isNotBlank()) { + testFullText += "throws $throwsException" + } + + // start writing the test lines + testFullText += "{\n" + + // write each line + lines.forEach { line -> + testFullText += when (line.type) { + TestLineType.BREAK -> "\t\t\n" + else -> "\t\t${line.text}\n" + } + } + + // close test case + testFullText += "\t}\n" + + return testFullText + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt new file mode 100644 index 000000000..8d61ce68e --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt @@ -0,0 +1,31 @@ +package org.jetbrains.research.testspark.core.test.kotlin + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.utils.CommandLineRunner + +class KotlinTestCompiler(libPaths: List, junitLibPaths: List) : + TestCompiler(libPaths, junitLibPaths) { + + private val log = KotlinLogging.logger { this::class.java } + + override fun compileCode(path: String, projectBuildPath: String): Pair { + log.info { "[KotlinTestCompiler] Compiling ${path.substringAfterLast('/')}" } + + val classPaths = "\"${getClassPaths(projectBuildPath)}\"" + // Compile file + val errorMsg = CommandLineRunner.run( + arrayListOf( + "kotlinc", + "-cp", + classPaths, + path, + ), + ) + + log.info { "Error message: '$errorMsg'" } + + // No need to save the .class file for kotlin, so checking the error message is enough + return Pair(errorMsg.isBlank(), errorMsg) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt deleted file mode 100644 index a8728bbf2..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.java - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class JavaJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "void", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt deleted file mode 100644 index 09bdbc627..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.kotlin - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class KotlinJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "fun", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt deleted file mode 100644 index 98c6827c5..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt +++ /dev/null @@ -1,173 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.strategies - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.test.data.TestLine -import org.jetbrains.research.testspark.core.test.data.TestLineType -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestCaseParseResult - -class JUnitTestSuiteParserStrategy { - companion object { - fun parseTestSuite( - rawText: String, - junitVersion: JUnitVersion, - importPattern: Regex, - packageName: String, - testNamePattern: String, - ): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - var rawCode = rawText - - if (rawText.contains("```")) { - rawCode = rawText.split("```")[1] - } - - // save imports - val imports = importPattern.findAll(rawCode, 0) - .map { it.groupValues[0] } - .toSet() - - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" - - val testSet: MutableList = rawCode.split("@Test").toMutableList() - - // save annotations and pre-set methods - val otherInfo: String = run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } - - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() - - testSet.forEach ca@{ - val rawTest = "@Test$it" - - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: TestCaseParseResult = - testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern) // /// - - if (result.errorOccurred) { - println("WARNING: ${result.errorMessage}") - return@ca - } - - val currentTest = result.testCase!! - - // TODO: make logging work - // log.info("New test case: $currentTest") - println("New test case: $currentTest") - - testCases.add(currentTest) - } - - val testSuite = TestSuiteGeneratedByLLM( - imports = imports, - packageString = packageName, - runWith = runWith, - otherInfo = otherInfo, - testCases = testCases, - ) - - return testSuite - } catch (e: Exception) { - return null - } - } - } -} - -private class JUnitTestCaseParser { - fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean, testNamePattern: String): TestCaseParseResult { - var expectedException = "" - var throwsException = "" - val testLines: MutableList = mutableListOf() - - // Get expected Exception - if (rawTest.startsWith("@Test(expected =")) { - expectedException = rawTest.split(")")[0].trim() - } - - // Get unexpected exceptions - /* Each test case should follow fun {...} - Tests do not return anything so it is safe to consider that void always appears before test case name - */ - val voidString = testNamePattern - if (!rawTest.contains(voidString)) { - return TestCaseParseResult( - testCase = null, - errorMessage = "The raw Test does not contain $voidString:\n $rawTest", - errorOccurred = true, - ) - } - val interestingPartOfSignature = rawTest.split(voidString)[1] - .split("{")[0] - .split("()")[1] - .trim() - - if (interestingPartOfSignature.contains("throws")) { - throwsException = interestingPartOfSignature.split("throws")[1].trim() - } - - // Get test name - val testName: String = rawTest.split(voidString)[1] - .split("()")[0] - .trim() - - // Get test body and remove opening bracket - var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } - .joinToString("{").trim() - - // remove closing bracket - val tempList = testBody.split("}").toMutableList() - tempList.removeLast() - - if (isLastTestCaseInTestSuite) { - // it is the last test, thus we should remove another closing bracket - if (tempList.isNotEmpty()) { - tempList.removeLast() - } else { - println("WARNING: the final test does not have the enclosing bracket:\n $testBody") - } - } - - testBody = tempList.joinToString("}") - - // Save each line - val rawLines = testBody.split("\n").toMutableList() - rawLines.forEach { rawLine -> - val line = rawLine.trim() - - val type: TestLineType = when { - line.startsWith("//") -> TestLineType.COMMENT - line.isBlank() -> TestLineType.BREAK - line.lowercase().startsWith("assert") -> TestLineType.ASSERTION - else -> TestLineType.CODE - } - - testLines.add(TestLine(type, line)) - } - - val currentTest = TestCaseGeneratedByLLM( - name = testName, - expectedException = expectedException, - throwsException = throwsException, - lines = testLines, - ) - - return TestCaseParseResult( - testCase = currentTest, - errorMessage = "", - errorOccurred = false, - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt new file mode 100644 index 000000000..7bc818cd0 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt @@ -0,0 +1,175 @@ +package org.jetbrains.research.testspark.core.test.strategies + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestCaseParseResult +import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM +import org.jetbrains.research.testspark.core.test.data.TestLine +import org.jetbrains.research.testspark.core.test.data.TestLineType +import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM + +class JUnitTestSuiteParserStrategy { + companion object { + fun parseJUnitTestSuite( + rawText: String, + junitVersion: JUnitVersion, + importPattern: Regex, + packageName: String, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestSuiteGeneratedByLLM? { + if (rawText.isBlank()) { + return null + } + + try { + val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText + + // save imports + val imports = importPattern.findAll(rawCode) + .map { it.groupValues[0] } + .toSet() + + // save RunWith + val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" + + val testSet: MutableList = rawCode.split("@Test").toMutableList() + + // save annotations and pre-set methods + val otherInfo: String = run { + val otherInfoList = testSet.removeAt(0).split("{").toMutableList() + otherInfoList.removeFirst() + val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" + otherInfo.ifBlank { "" } + } + + // Save the main test cases + val testCases: MutableList = mutableListOf() + val testCaseParser = JUnitTestCaseParser() + + testSet.forEach ca@{ + val rawTest = "@Test$it" + + val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) + val result: TestCaseParseResult = + testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy) + + if (result.errorOccurred) { + println("WARNING: ${result.errorMessage}") + return@ca + } + + val currentTest = result.testCase!! + + // TODO: make logging work + // log.info("New test case: $currentTest") + + testCases.add(currentTest) + } + + val testSuite = TestSuiteGeneratedByLLM( + imports = imports, + packageName = packageName, + runWith = runWith, + otherInfo = otherInfo, + testCases = testCases, + ) + + return testSuite + } catch (e: Exception) { + return null + } + } + } + + private class JUnitTestCaseParser { + fun parse( + rawTest: String, + isLastTestCaseInTestSuite: Boolean, + testNamePattern: String, + printTestBodyStrategy: TestBodyPrinter, + ): TestCaseParseResult { + var expectedException = "" + var throwsException = "" + val testLines: MutableList = mutableListOf() + + // Get expected Exception + if (rawTest.startsWith("@Test(expected =")) { + expectedException = rawTest.split(")")[0].trim() + } + + // Get unexpected exceptions + /* Each test case should follow fun {...} + Tests do not return anything so it is safe to consider that void always appears before test case name + */ + if (!rawTest.contains(testNamePattern)) { + return TestCaseParseResult( + testCase = null, + errorMessage = "The raw Test does not contain $testNamePattern:\n $rawTest", + errorOccurred = true, + ) + } + val interestingPartOfSignature = rawTest.split(testNamePattern)[1] + .split("{")[0] + .split("()")[1] + .trim() + + if (interestingPartOfSignature.contains("throws")) { + throwsException = interestingPartOfSignature.split("throws")[1].trim() + } + + // Get test name + val testName: String = rawTest.split(testNamePattern)[1] + .split("()")[0] + .trim() + + // Get test body and remove opening bracket + var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } + .joinToString("{").trim() + + // remove closing bracket + val tempList = testBody.split("}").toMutableList() + tempList.removeLast() + + if (isLastTestCaseInTestSuite) { + // it is the last test, thus we should remove another closing bracket + if (tempList.isNotEmpty()) { + tempList.removeLast() + } else { + println("WARNING: the final test does not have the enclosing bracket:\n $testBody") + } + } + + testBody = tempList.joinToString("}") + + // Save each line + val rawLines = testBody.split("\n").toMutableList() + rawLines.forEach { rawLine -> + val line = rawLine.trim() + + val type: TestLineType = when { + line.startsWith("//") -> TestLineType.COMMENT + line.isBlank() -> TestLineType.BREAK + line.lowercase().startsWith("assert") -> TestLineType.ASSERTION + else -> TestLineType.CODE + } + + testLines.add(TestLine(type, line)) + } + + val currentTest = TestCaseGeneratedByLLM( + name = testName, + expectedException = expectedException, + throwsException = throwsException, + lines = testLines, + printTestBodyStrategy = printTestBodyStrategy, + ) + + return TestCaseParseResult( + testCase = currentTest, + errorMessage = "", + errorOccurred = false, + ) + } + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt deleted file mode 100644 index 250ec7cba..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.core.utils - -/** - * Language ID string should be the same as the language name in com.intellij.lang.Language - */ -enum class Language(val languageId: String) { - Java("JAVA"), Kotlin("Kotlin") -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt index 95903bf8c..fb1da6841 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt @@ -6,9 +6,17 @@ val javaImportPattern = options = setOf(RegexOption.MULTILINE), ) +/** + * Parse all the possible Kotlin import patterns + * + * import org.mockito.Mockito.`when` + * import kotlin.math.cos + * import kotlin.math.* + * import kotlin.math.PI as piValue + */ val kotlinImportPattern = Regex( - pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?", + pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?(`\\w*`)?", options = setOf(RegexOption.MULTILINE), ) diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt index 2ebcde0c9..63fbd0abc 100644 --- a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -2,14 +2,17 @@ package org.jetbrains.research.testspark.core.test.parsers.kotlin import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test -import kotlin.test.assertNotNull class KotlinJUnitTestSuiteParserTest { @Test - fun testFunction() { + fun testParseTestSuite() { val text = """ ```kotlin import org.junit.jupiter.api.Assertions.* @@ -109,17 +112,149 @@ class KotlinJUnitTestSuiteParserTest { } ``` """.trimIndent() - val parser = KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern) + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertTrue(testSuite!!.imports.contains("import org.mockito.Mockito.*")) + assertTrue(testSuite.imports.contains("import org.test.Message as TestMessage")) + assertTrue(testSuite.imports.contains("import org.mockito.kotlin.mock")) + + val expectedTestCasesNames = listOf( + "compileTestCases_AllCompilableTest", + "compileTestCases_NoneCompilableTest", + "compileTestCases_SomeCompilableTest", + "compileTestCases_EmptyTestCasesTest", + "compileTestCases_omg", + ) + + testSuite.testCases.forEachIndexed { index, testCase -> + val expected = expectedTestCasesNames[index] + assertEquals(expected, testCase.name) { "${index + 1}st test case has incorrect name" } + } + + assertTrue(testSuite.testCases[4].expectedException.isNotBlank()) + } + + @Test + fun testParseEmptyTestSuite() { + val text = """ + ```kotlin + package com.example.testsuite + + class EmptyTestClass { + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(testSuite!!.packageName, "com.example.testsuite") + assertTrue(testSuite.testCases.isEmpty()) + } + + @Test + fun testParseSingleTestCase() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class SingleTestCaseClass { + @Test + fun singleTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) - assert(testSuite.imports.contains("import org.mockito.Mockito.*")) - assert(testSuite.imports.contains("import org.test.Message as TestMessage")) - assert(testSuite.imports.contains("import org.mockito.kotlin.mock")) - assert(testSuite.testCases[0].name == "compileTestCases_AllCompilableTest") - assert(testSuite.testCases[1].name == "compileTestCases_NoneCompilableTest") - assert(testSuite.testCases[2].name == "compileTestCases_SomeCompilableTest") - assert(testSuite.testCases[3].name == "compileTestCases_EmptyTestCasesTest") - assert(testSuite.testCases[4].name == "compileTestCases_omg") - assert(testSuite.testCases[4].expectedException.isNotBlank()) + assertEquals(1, testSuite!!.testCases.size) + assertEquals("singleTestCase", testSuite.testCases[0].name) + } + + @Test + fun testParseTwoTestCases() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class TwoTestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + + @Test + fun secondTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(2, testSuite!!.testCases.size) + assertEquals("firstTestCase", testSuite.testCases[0].name) + assertEquals("secondTestCase", testSuite.testCases[1].name) + } + + @Test + fun testParseTwoTestCasesWithDifferentPackage() { + val code1 = """ + ```kotlin + package org.pkg1 + + import org.junit.jupiter.api.Test + + class TestCasesClass1 { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val code2 = """ + ```kotlin + package org.pkg2 + + import org.junit.jupiter.api.Test + + class 2TestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + + // packageName will be set to 'org.pkg1' + val testSuite1 = parser.parseTestSuite(code1) + + val testSuite2 = parser.parseTestSuite(code2) + + assertNotNull(testSuite1) + assertNotNull(testSuite2) + assertEquals("org.pkg1", testSuite1!!.packageName) + assertEquals("org.pkg2", testSuite2!!.packageName) } } diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index 007bdbff7..087485827 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -14,6 +14,7 @@ import org.jetbrains.research.testspark.core.utils.javaImportPattern import org.jetbrains.research.testspark.core.utils.javaPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -33,29 +34,12 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - javaPackagePattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - javaImportPattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + javaPackagePattern, + javaImportPattern, + ) override val classType: ClassType get() { @@ -68,6 +52,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { return ClassType.CLASS } + override val rBrace: Int? = psiClass.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val query = ClassInheritorsSearch.search(psiClass, scope, false) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index 8b513deda..f6f132a29 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -4,23 +4,27 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiMethod import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Java + override val language: SupportedLanguage get() = SupportedLanguage.Java private val log = Logger.getInstance(this::class.java) @@ -63,7 +67,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -84,7 +88,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { val cutPsiClass = getSurroundingClass(caretOffset)!! var currentPsiClass = cutPsiClass @@ -138,39 +142,44 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + // The cut is always not null for Java, because all functions are always inside the class + val interestingPsiClasses = cut!!.getInterestingPsiClassesWithQualifiedNames(psiMethod) log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper? val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper? - val line: Int? = getSurroundingLine(caret.offset) - - javaPsiClassWrapped?.let { result.add(getClassHTMLDisplayName(it)) } - javaPsiMethodWrapped?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (javaPsiClassWrapped != null && javaPsiMethodWrapped != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${javaPsiClassWrapped.qualifiedName} \n" + - " 2) Method ${javaPsiMethodWrapped.name} \n" + - " 3) Line $line", - ) - } + val line: Int? = getSurroundingLineNumber(caret.offset) + + javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${javaPsiClassWrapped?.qualifiedName ?: "no class"} \n" + + " 2) Method ${javaPsiMethodWrapped?.name ?: "no method"} \n" + + " 3) Line $line", + ) - return result.toArray() + return result } + override fun getPackageName() = (psiFile as PsiJavaFile).packageName + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index 8ac75755c..50cc12f0f 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -21,6 +21,7 @@ import org.jetbrains.research.testspark.core.utils.kotlinImportPattern import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -61,29 +62,12 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - kotlinPackagePattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - kotlinImportPattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + kotlinPackagePattern, + kotlinImportPattern, + ) override val classType: ClassType get() { @@ -97,6 +81,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra } } + override val rBrace: Int? = psiClass.body?.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val lightClass = psiClass.toLightClass() @@ -116,11 +102,9 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra method.psiFunction.valueParameters.forEach { parameter -> val typeReference = parameter.typeReference - if (typeReference != null) { - val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) - if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { - interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) - } + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 13749bd35..8c209b9fd 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -4,6 +4,7 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass @@ -14,21 +15,20 @@ import org.jetbrains.kotlin.asJava.toLightClass import org.jetbrains.kotlin.descriptors.ClassDescriptor import org.jetbrains.kotlin.idea.base.psi.kotlinFqName import org.jetbrains.kotlin.idea.caches.resolve.analyze -import org.jetbrains.kotlin.psi.KtClass -import org.jetbrains.kotlin.psi.KtClassOrObject -import org.jetbrains.kotlin.psi.KtFunction -import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper -class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { +class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Kotlin + override val language: SupportedLanguage get() = SupportedLanguage.Kotlin private val log = Logger.getInstance(this::class.java) @@ -66,7 +66,7 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -85,9 +85,10 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { - val cutPsiClass = getSurroundingClass(caretOffset)!! + val cutPsiClass = getSurroundingClass(caretOffset) ?: return + // will be null for the top level function var currentPsiClass = cutPsiClass for (index in 0 until maxPolymorphismDepth) { if (!classesToTest.contains(currentPsiClass)) { @@ -116,19 +117,13 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { repeat(maxInputParamsDepth) { val tempListOfClasses = mutableSetOf() - currentLevelClasses.forEach { classIt -> classIt.methods.forEach { methodIt -> (methodIt as KotlinPsiMethodWrapper).parameterList?.parameters?.forEach { paramIt -> - val typeRef = paramIt.typeReference - if (typeRef != null) { - resolveClassInType(typeRef)?.let { psiClass -> - if (psiClass.kotlinFqName != null) { - KotlinPsiClassWrapper(psiClass as KtClass).let { - if (!it.qualifiedName.startsWith("kotlin.")) { - interestingPsiClasses.add(it) - } - } + KtPsiUtil.getClassIfParameterIsProperty(paramIt)?.let { typeIt -> + KotlinPsiClassWrapper(typeIt).let { + if (!it.qualifiedName.startsWith("kotlin.")) { + interestingPsiClasses.add(it) } } } @@ -143,39 +138,45 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + val interestingPsiClasses = + cut?.getInterestingPsiClassesWithQualifiedNames(psiMethod) + ?: (psiMethod as KotlinPsiMethodWrapper).getInterestingPsiClassesWithQualifiedNames() log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) - val line: Int? = getSurroundingLine(caret.offset)?.plus(1) - - ktClass?.let { result.add(getClassHTMLDisplayName(it)) } - ktFunction?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (ktClass != null && ktFunction != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${ktClass.qualifiedName} \n" + - " 2) Method ${ktFunction.name} \n" + - " 3) Line $line", - ) - } + val line: Int? = getSurroundingLineNumber(caret.offset)?.plus(1) + + ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } - return result.toArray() + log.info( + "The test can be generated for: \n " + + " 1) Class ${ktClass?.qualifiedName ?: "no class"} \n" + + " 2) Method ${ktFunction?.name ?: "no method"} \n" + + " 3) Line $line", + ) + + return result } + override fun getPackageName() = (psiFile as KtFile).packageFqName.asString() + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = @@ -184,18 +185,11 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { override fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String { psiMethod as KotlinPsiMethodWrapper return when { - psiMethod.isTopLevelFunction -> "top-level function" + psiMethod.isTopLevelFunction -> "top-level function ${psiMethod.name}" psiMethod.isSecondaryConstructor -> "secondary constructor" psiMethod.isPrimaryConstructor -> "constructor" psiMethod.isDefaultMethod -> "default method ${psiMethod.name}" else -> "method ${psiMethod.name}" } } - - private fun resolveClassInType(typeReference: KtTypeReference): PsiClass? { - val context = typeReference.analyze(BodyResolveMode.PARTIAL) - val type = context[BindingContext.TYPE, typeReference] ?: return null - val classDescriptor = type.constructor.declarationDescriptor as? ClassDescriptor ?: return null - return (DescriptorToSourceUtils.getSourceFromDescriptor(classDescriptor) as? KtClass)?.toLightClass() - } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index a142aaaa8..c993fd808 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -68,6 +68,26 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { return lineNumber in startLine..endLine } + /** + * Returns a set of `PsiClassWrapper` instances for non-standard Kotlin classes referenced by the + * parameters of the current function. + * + * @return A mutable set of `PsiClassWrapper` instances representing non-standard Kotlin classes. + */ + fun getInterestingPsiClassesWithQualifiedNames(): MutableSet { + val interestingPsiClasses = mutableSetOf() + + psiFunction.valueParameters.forEach { parameter -> + val typeReference = parameter.typeReference + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) + } + } + + return interestingPsiClasses + } + /** * Generates the return descriptor for a method. * diff --git a/langwrappers/build.gradle.kts b/langwrappers/build.gradle.kts index 74ec82496..317debb35 100644 --- a/langwrappers/build.gradle.kts +++ b/langwrappers/build.gradle.kts @@ -5,7 +5,6 @@ plugins { repositories { mavenCentral() - // Add any other repositories you need } dependencies { @@ -17,7 +16,6 @@ dependencies { intellij { rootProject.properties["platformVersion"]?.let { version.set(it.toString()) } plugins.set(listOf("java")) - downloadSources.set(true) } tasks.named("verifyPlugin") { enabled = false } diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt new file mode 100644 index 000000000..0982b9ced --- /dev/null +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/LanguageClassTextExtractor.kt @@ -0,0 +1,7 @@ +package org.jetbrains.research.testspark.langwrappers + +import com.intellij.psi.PsiFile + +interface LanguageClassTextExtractor { + fun extract(file: PsiFile, classText: String, packagePattern: Regex, importPattern: Regex): String +} diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index f61dc7a1b..0aa5dfd0f 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -1,11 +1,15 @@ package org.jetbrains.research.testspark.langwrappers import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.editor.Document import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType + +typealias CodeTypeDisplayName = Pair /** * Interface representing a wrapper for PSI methods, @@ -40,12 +44,14 @@ interface PsiMethodWrapper { * @property name The name of a class * @property qualifiedName The qualified name of the class. * @property text The text of the class. - * @property fullText The source code of the class (with package and imports). - * @property virtualFile - * @property containingFile File where the method is located - * @property superClass The super class of the class * @property methods All methods in the class * @property allMethods All methods in the class and all its superclasses + * @property superClass The super class of the class + * @property virtualFile Virtual file where the class is located + * @property containingFile File where the method is located + * @property fullText The source code of the class (with package and imports). + * @property classType The type of the class + * @property rBrace The offset of the closing brace * */ interface PsiClassWrapper { val name: String @@ -58,6 +64,7 @@ interface PsiClassWrapper { val containingFile: PsiFile val fullText: String val classType: ClassType + val rBrace: Int? /** * Searches for subclasses of the current class within the given project. @@ -81,7 +88,7 @@ interface PsiClassWrapper { * handling the PSI (Program Structure Interface) for different languages. */ interface PsiHelper { - val language: Language + val language: SupportedLanguage /** * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile. @@ -107,7 +114,7 @@ interface PsiHelper { * @param caretOffset The caret offset within the PSI file. * @return The line number of the selected line, otherwise null. */ - fun getSurroundingLine(caretOffset: Int): Int? + fun getSurroundingLineNumber(caretOffset: Int): Int? /** * Retrieves a set of interesting PsiClasses based on a given project, @@ -133,7 +140,7 @@ interface PsiHelper { * @return A mutable set of interesting PsiClasses. */ fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet @@ -145,7 +152,7 @@ interface PsiHelper { * The array contains the class display name, method display name (if present), and the line number (if present). * The line number is prefixed with "Line". */ - fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? + fun getCurrentListOfCodeTypes(e: AnActionEvent): List /** * Helper for generating method descriptors for methods. @@ -160,8 +167,8 @@ interface PsiHelper { * * @param project The project in which to collect classes to test. * @param classesToTest The list of classes to test. - * @param psiHelper The PSI helper instance to use for collecting classes. * @param caretOffset The caret offset in the file. + * @param maxPolymorphismDepth Check if cut has any user-defined superclass */ fun collectClassesToTest( project: Project, @@ -170,6 +177,21 @@ interface PsiHelper { maxPolymorphismDepth: Int, ) + /** + * Get the package name of the file. + */ + fun getPackageName(): String + + /** + * Get the module of the file. + */ + fun getModuleFromPsiFile(): com.intellij.openapi.module.Module + + /** + * Get the module of the file. + */ + fun getDocumentFromPsiFile(): Document? + /** * Gets the display line number. * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry. diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt new file mode 100644 index 000000000..643cdee34 --- /dev/null +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/strategies/JavaKotlinClassTextExtractor.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.langwrappers.strategies + +import com.intellij.psi.PsiFile +import org.jetbrains.research.testspark.langwrappers.LanguageClassTextExtractor + +/** +Direct implementor for the Java and Kotlin PsiWrappers + */ +class JavaKotlinClassTextExtractor : LanguageClassTextExtractor { + + override fun extract( + file: PsiFile, + classText: String, + packagePattern: Regex, + importPattern: Regex, + ): String { + var fullText = "" + val fileText = file.text + + // get package + packagePattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n\n" + } + + // get imports + importPattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n" + } + + // Add class code + fullText += classText + + return fullText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 3b08ca009..a6f342882 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -17,6 +17,7 @@ import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelFactory import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.display.TestSparkIcons import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider @@ -76,7 +77,6 @@ class TestSparkAction : AnAction() { if (psiHelper == null) { // TODO exception } - e.presentation.isEnabled = psiHelper!!.getCurrentListOfCodeTypes(e) != null } /** @@ -111,18 +111,18 @@ class TestSparkAction : AnAction() { return psiHelper!! } - private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e)!! + private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e) private val caretOffset: Int = e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret!!.offset private val fileUrl = e.dataContext.getData(CommonDataKeys.VIRTUAL_FILE)!!.presentableUrl - private val codeTypeButtons: MutableList = mutableListOf() + private val codeTypeButtons: MutableList> = mutableListOf() private val codeTypeButtonGroup = ButtonGroup() private val nextButton = JButton(PluginLabelsBundle.get("next")) private val cardLayout = CardLayout() private val llmSetupPanelFactory = LLMSetupPanelFactory(e, project) - private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project) + private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project, psiHelper.language) private val evoSuitePanelFactory = EvoSuitePanelFactory(project) init { @@ -198,16 +198,19 @@ class TestSparkAction : AnAction() { testGeneratorPanel.add(llmButton) testGeneratorPanel.add(evoSuiteButton) - for (codeType in codeTypes) { - val button = JRadioButton(codeType as String) - codeTypeButtons.add(button) + for ((codeType, codeTypeName) in codeTypes) { + val button = JRadioButton(codeTypeName) + codeTypeButtons.add(codeType to button) codeTypeButtonGroup.add(button) } val codesToTestPanel = JPanel() codesToTestPanel.add(JLabel("Select the code type:")) - if (codeTypeButtons.size == 1) codeTypeButtons[0].isSelected = true - for (button in codeTypeButtons) codesToTestPanel.add(button) + if (codeTypeButtons.size == 1) { + // A single button is selected by default + codeTypeButtons[0].second.isSelected = true + } + for ((_, button) in codeTypeButtons) codesToTestPanel.add(button) val middlePanel = FormBuilder.createFormBuilder() .setFormLeftIndent(10) @@ -253,7 +256,7 @@ class TestSparkAction : AnAction() { updateNextButton() } - for (button in codeTypeButtons) { + for ((_, button) in codeTypeButtons) { button.addActionListener { llmSetupPanelFactory.setPromptEditorType(button.text) updateNextButton() @@ -330,33 +333,36 @@ class TestSparkAction : AnAction() { if (!testGenerationController.isGeneratorRunning(project)) { val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode() - if (codeTypeButtons[0].isSelected) { - tool.generateTestsForClass( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[1].isSelected) { - tool.generateTestsForMethod( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[2].isSelected) { - tool.generateTestsForLine( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) + for ((codeType, button) in codeTypeButtons) { + if (button.isSelected) { + when (codeType) { + CodeType.CLASS -> tool.generateTestsForClass( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.METHOD -> tool.generateTestsForMethod( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.LINE -> tool.generateTestsForLine( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + } + break + } } } @@ -376,10 +382,7 @@ class TestSparkAction : AnAction() { */ private fun updateNextButton() { val isTestGeneratorButtonGroupSelected = llmButton.isSelected || evoSuiteButton.isSelected - var isCodeTypeButtonGroupSelected = false - for (button in codeTypeButtons) { - isCodeTypeButtonGroupSelected = isCodeTypeButtonGroupSelected || button.isSelected - } + val isCodeTypeButtonGroupSelected = codeTypeButtons.any { it.second.isSelected } nextButton.isEnabled = isTestGeneratorButtonGroupSelected && isCodeTypeButtonGroupSelected if ((llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) || @@ -393,4 +396,4 @@ class TestSparkAction : AnAction() { } override fun getActionUpdateThread(): ActionUpdateThread = ActionUpdateThread.BGT -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index b57ee8d81..b6b77a0ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,6 +4,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup @@ -12,7 +13,7 @@ import javax.swing.JLabel import javax.swing.JPanel import javax.swing.JRadioButton -class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { +class LLMSampleSelectorFactory(private val project: Project, private val language: SupportedLanguage) : PanelFactory { // init components private val selectionTypeButtons: MutableList = mutableListOf( JRadioButton(PluginLabelsBundle.get("provideTestSample")), @@ -128,7 +129,7 @@ class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { } addButton.addActionListener { - val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes) + val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes, language) testSamplePanelFactories.add(testSamplePanelFactory) val testSamplePanel = testSamplePanelFactory.getTestSamplePanel() val codeScrollPanel = testSamplePanelFactory.getCodeScrollPanel() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt index 69d5db9f3..8afe31fc8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt @@ -34,7 +34,7 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan private val defaultModulesArray = arrayOf("") private var modelSelector = ComboBox(defaultModulesArray) private var llmUserTokenField = JTextField(30) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) private val backLlmButton = JButton(PluginLabelsBundle.get("back")) private val okLlmButton = JButton(PluginLabelsBundle.get("next")) private val junitSelector = JUnitCombobox(e) @@ -142,6 +142,10 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan llmSettingsState.grazieToken = llmPlatforms[index].token llmSettingsState.grazieModel = llmPlatforms[index].model } + if (llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = llmPlatforms[index].token + llmSettingsState.huggingFaceModel = llmPlatforms[index].model + } } llmSettingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt index 97cf6d49a..251a45f27 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt @@ -10,6 +10,7 @@ import com.intellij.openapi.ui.ComboBox import com.intellij.ui.LanguageTextField import com.intellij.ui.components.JBScrollPane import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.IconButtonCreator import org.jetbrains.research.testspark.display.ModifiedLinesGetter import org.jetbrains.research.testspark.display.TestCaseDocumentCreator @@ -25,11 +26,12 @@ class TestSamplePanelFactory( private val middlePanel: JPanel, private val testNames: MutableList, private val initialTestCodes: MutableList, + private val language: SupportedLanguage, ) { // init components private val currentTestCodes = initialTestCodes.toMutableList() private val languageTextField = LanguageTextField( - Language.findLanguageByID("JAVA"), + Language.findLanguageByID(language.languageId), project, initialTestCodes[0], TestCaseDocumentCreator("TestSample"), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index b8b0654d3..499abf1c1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -18,7 +18,8 @@ import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.llm.JsonEncoding @@ -26,6 +27,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.Llm @@ -172,6 +174,12 @@ class TestSparkStarter : ApplicationStarter { // Start test generation val indicator = HeadlessProgressIndicator() val errorMonitor = DefaultErrorMonitor() + val testCompiler = TestCompilerFactory.create( + project, + settingsState.junitVersion, + psiHelper.language, + projectSDKPath.toString(), + ) val uiContext = llmProcessManager.runTestGenerator( indicator, FragmentToTestData(CodeType.CLASS), @@ -192,6 +200,7 @@ class TestSparkStarter : ApplicationStarter { classPath, projectContext, projectSDKPath, + testCompiler, ) } else { println("[TestSpark Starter] Test generation failed") @@ -237,6 +246,7 @@ class TestSparkStarter : ApplicationStarter { classPath: String, projectContext: ProjectContext, projectSDKPath: Path, + testCompiler: TestCompiler, ) { val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}" println("Run tests in $targetDirectory") @@ -246,6 +256,7 @@ class TestSparkStarter : ApplicationStarter { var testcaseName = it.nameWithoutExtension.removePrefix("Generated") testcaseName = testcaseName[0].lowercaseChar() + testcaseName.substring(1) // The current test is compiled and is ready to run jacoco + val testExecutionError = TestProcessor(project, projectSDKPath).createXmlFromJacoco( it.nameWithoutExtension, "$targetDirectory${File.separator}jacoco-${it.nameWithoutExtension}", @@ -254,6 +265,7 @@ class TestSparkStarter : ApplicationStarter { packageList.joinToString("."), out, projectContext, + testCompiler, ) // Saving exception (if exists) thrown during the test execution saveException(testcaseName, targetDirectory, testExecutionError) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt index 0cf79dddb..3c289bb11 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.data +import org.jetbrains.research.testspark.core.test.data.CodeType + /** * Data about test objects that require test generators. */ diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt index f17e8720b..99b0ec5ab 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt @@ -25,17 +25,20 @@ import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.helpers.ReportHelper import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestClassCodeAnalyzerFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.test.JUnitTestSuitePresenter @@ -58,7 +61,7 @@ import javax.swing.border.MatteBorder class TestCasePanelFactory( private val project: Project, - private val language: org.jetbrains.research.testspark.core.utils.Language, + private val language: SupportedLanguage, private val testCase: TestCase, editor: Editor, private val checkbox: JCheckBox, @@ -193,7 +196,10 @@ class TestCasePanelFactory( val clipboard: Clipboard = Toolkit.getDefaultToolkit().systemClipboard clipboard.setContents( StringSelection( - project.service().getEditor(testCase.testName)!!.document.text, + when (language) { + SupportedLanguage.Kotlin -> project.service().getEditor(testCase.testName)!!.document.text + SupportedLanguage.Java -> project.service().getEditor(testCase.testName)!!.document.text + }, ), null, ) @@ -386,7 +392,10 @@ class TestCasePanelFactory( } ReportHelper.updateTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service().updateUI() + SupportedLanguage.Java -> project.service().updateUI() + } } /** @@ -454,12 +463,12 @@ class TestCasePanelFactory( } private fun addTest(testSuite: TestSuiteGeneratedByLLM) { - val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput) + val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput, language) WriteCommandAction.runWriteCommandAction(project) { uiContext.errorMonitor.clear() val code = testSuitePresenter.toString(testSuite) - testCase.testName = JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, code) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, code) testCase.testCode = code // update numbers @@ -517,15 +526,24 @@ class TestCasePanelFactory( private fun runTest(indicator: CustomProgressIndicator) { indicator.setText("Executing ${testCase.testName}") + val fileName = TestClassCodeAnalyzerFactory.create(language).getFileNameFromTestCaseCode(testCase.testName) + + val testCompiler = TestCompilerFactory.create( + project, + llmSettingsState.junitVersion, + language, + ) + val newTestCase = TestProcessor(project) .processNewTestCase( - "${JavaClassBuilderHelper.getClassFromTestCaseCode(testCase.testCode)}.java", + fileName, testCase.id, testCase.testName, testCase.testCode, - uiContext!!.testGenerationOutput.packageLine, + uiContext!!.testGenerationOutput.packageName, uiContext.testGenerationOutput.resultPath, uiContext.projectContext, + testCompiler, ) testCase.coveredLines = newTestCase.coveredLines @@ -585,13 +603,23 @@ class TestCasePanelFactory( */ private fun remove() { // Remove the test case from the cache - project.service().removeTestCase(testCase.testName) + when (language) { + SupportedLanguage.Kotlin -> project.service().removeTestCase(testCase.testName) + + SupportedLanguage.Java -> project.service().removeTestCase(testCase.testName) + } runTestButton.isEnabled = false isRemoved = true ReportHelper.removeTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service() + .updateUI() + + SupportedLanguage.Java -> project.service() + .updateUI() + } } /** @@ -663,8 +691,7 @@ class TestCasePanelFactory( * Updates the current test case with the specified test name and test code. */ private fun updateTestCaseInformation() { - testCase.testName = - JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, languageTextField.document.text) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, languageTextField.document.text) testCase.testCode = languageTextField.document.text } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt index 31cc7b9a6..b8f90918c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt @@ -1,6 +1,5 @@ package org.jetbrains.research.testspark.display -import com.intellij.openapi.components.service import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task @@ -8,20 +7,20 @@ import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.display.strategies.TopButtonsPanelStrategy import java.awt.Dimension import java.util.LinkedList import java.util.Queue import javax.swing.Box import javax.swing.BoxLayout import javax.swing.JButton -import javax.swing.JCheckBox import javax.swing.JLabel import javax.swing.JOptionPane import javax.swing.JPanel -class TopButtonsPanelFactory(private val project: Project) { +class TopButtonsPanelFactory(private val project: Project, private val language: SupportedLanguage) { private var runAllButton: JButton = createRunAllTestButton() private var selectAllButton: JButton = IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) @@ -64,28 +63,26 @@ class TopButtonsPanelFactory(private val project: Project) { * Updates the labels. */ fun updateTopLabels() { - var numberOfPassedTests = 0 - for (testCasePanelFactory in testCasePanelFactories) { - if (testCasePanelFactory.isRemoved()) continue - val error = testCasePanelFactory.getError() - if ((error is String) && error.isEmpty()) { - numberOfPassedTests++ - } - } - testsSelectedLabel.text = String.format( - testsSelectedText, - project.service().getTestsSelected(), - project.service().getTestCasePanels().size, - ) - testsPassedLabel.text = - String.format( + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.updateTopJavaLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, testsPassedText, - numberOfPassedTests, - project.service().getTestCasePanels().size, + runAllButton, + ) + + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.updateTopKotlinLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, + testsPassedText, + runAllButton, ) - runAllButton.isEnabled = false - for (testCasePanelFactory in testCasePanelFactories) { - runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() } } @@ -105,31 +102,20 @@ class TopButtonsPanelFactory(private val project: Project) { * @param selected whether the checkboxes have to be selected or not */ private fun toggleAllCheckboxes(selected: Boolean) { - project.service().getTestCasePanels().forEach { (_, jPanel) -> - val checkBox = jPanel.getComponent(0) as JCheckBox - checkBox.isSelected = selected + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.toggleAllJavaCheckboxes(selected, project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.toggleAllKotlinCheckboxes(selected, project) } - project.service() - .setTestsSelected(if (selected) project.service().getTestCasePanels().size else 0) } /** * Removes all test cases from the cache and tool window UI. */ private fun removeAllTestCases() { - // Ask the user for the confirmation - val choice = JOptionPane.showConfirmDialog( - null, - PluginMessagesBundle.get("removeAllMessage"), - PluginMessagesBundle.get("confirmationTitle"), - JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE, - ) - - // Cancel the operation if the user did not press "Yes" - if (choice == JOptionPane.NO_OPTION) return - - project.service().clear() + when (language) { + SupportedLanguage.Java -> TopButtonsPanelStrategy.removeAllJavaTestCases(project) + SupportedLanguage.Kotlin -> TopButtonsPanelStrategy.removeAllKotlinTestCases(project) + } } /** diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt new file mode 100644 index 000000000..07d8f88f2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/strategies/TopButtonsPanelStrategy.kt @@ -0,0 +1,138 @@ +package org.jetbrains.research.testspark.display.strategies + +import com.intellij.openapi.components.service +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JLabel +import javax.swing.JOptionPane + +class TopButtonsPanelStrategy { + companion object { + fun toggleAllJavaCheckboxes(selected: Boolean, project: Project) { + project.service().getTestCasePanels().forEach { (_, jPanel) -> + val checkBox = jPanel.getComponent(0) as JCheckBox + checkBox.isSelected = selected + } + project.service() + .setTestsSelected( + if (selected) project.service().getTestCasePanels().size else 0, + ) + } + + fun toggleAllKotlinCheckboxes(selected: Boolean, project: Project) { + project.service().getTestCasePanels().forEach { (_, jPanel) -> + val checkBox = jPanel.getComponent(0) as JCheckBox + checkBox.isSelected = selected + } + project.service() + .setTestsSelected( + if (selected) project.service().getTestCasePanels().size else 0, + ) + } + + fun updateTopJavaLabels( + testCasePanelFactories: ArrayList, + testsSelectedLabel: JLabel, + testsSelectedText: String, + project: Project, + testsPassedLabel: JLabel, + testsPassedText: String, + runAllButton: JButton, + ) { + var numberOfPassedTests = 0 + for (testCasePanelFactory in testCasePanelFactories) { + if (testCasePanelFactory.isRemoved()) continue + val error = testCasePanelFactory.getError() + if ((error is String) && error.isEmpty()) { + numberOfPassedTests++ + } + } + testsSelectedLabel.text = String.format( + testsSelectedText, + project.service().getTestsSelected(), + project.service().getTestCasePanels().size, + ) + testsPassedLabel.text = + String.format( + testsPassedText, + numberOfPassedTests, + project.service().getTestCasePanels().size, + ) + runAllButton.isEnabled = false + for (testCasePanelFactory in testCasePanelFactories) { + runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() + } + } + + fun updateTopKotlinLabels( + testCasePanelFactories: ArrayList, + testsSelectedLabel: JLabel, + testsSelectedText: String, + project: Project, + testsPassedLabel: JLabel, + testsPassedText: String, + runAllButton: JButton, + ) { + var numberOfPassedTests = 0 + for (testCasePanelFactory in testCasePanelFactories) { + if (testCasePanelFactory.isRemoved()) continue + val error = testCasePanelFactory.getError() + if ((error is String) && error.isEmpty()) { + numberOfPassedTests++ + } + } + testsSelectedLabel.text = String.format( + testsSelectedText, + project.service().getTestsSelected(), + project.service().getTestCasePanels().size, + ) + testsPassedLabel.text = + String.format( + testsPassedText, + numberOfPassedTests, + project.service().getTestCasePanels().size, + ) + runAllButton.isEnabled = false + for (testCasePanelFactory in testCasePanelFactories) { + runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() + } + } + + fun removeAllJavaTestCases(project: Project) { + // Ask the user for the confirmation + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("removeAllMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.YES_NO_OPTION, + JOptionPane.QUESTION_MESSAGE, + ) + + // Cancel the operation if the user did not press "Yes" + if (choice == JOptionPane.NO_OPTION) return + + project.service().clear() + } + + fun removeAllKotlinTestCases(project: Project) { + // Ask the user for the confirmation + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("removeAllMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.YES_NO_OPTION, + JOptionPane.QUESTION_MESSAGE, + ) + + // Cancel the operation if the user did not press "Yes" + if (choice == JOptionPane.NO_OPTION) return + + project.service().clear() + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt index bcad7a834..dee6a2b0e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt @@ -16,7 +16,7 @@ import com.intellij.ui.components.JBLabel import com.intellij.ui.components.JBScrollPane import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.services.EvoSuiteSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState import java.awt.Color import java.awt.Dimension @@ -130,7 +130,7 @@ class CoverageHelper( * @param name name of the test to highlight */ private fun highlightInToolwindow(name: String) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightTestCase(name) } @@ -141,7 +141,7 @@ class CoverageHelper( * @param map map of mutant operations -> List of names of tests which cover the mutants */ private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() }) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt deleted file mode 100644 index 977873bdb..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ /dev/null @@ -1,204 +0,0 @@ -package org.jetbrains.research.testspark.helpers - -import com.github.javaparser.ParseProblemException -import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.visitor.VoidVisitorAdapter -import com.intellij.lang.java.JavaLanguage -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.project.Project -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiFile -import com.intellij.psi.PsiFileFactory -import com.intellij.psi.codeStyle.CodeStyleManager -import org.jetbrains.research.testspark.core.data.TestGenerationData -import java.io.File - -object JavaClassBuilderHelper { - /** - * Generates the code for a test class. - * - * @param className the name of the test class - * @param body the body of the test class - * @return the generated code as a string - */ - fun generateCode( - project: Project, - className: String, - body: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - testGenerationData: TestGenerationData, - ): String { - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) - - // Add each test (exclude expected exception) - testFullText += body - - // close the test class - testFullText += "}" - - testFullText.replace("\r\n", "\n") - - /** - * for better readability and make the tests shorter, we reduce the number of line breaks: - * when we have three or more sequential \n, reduce it to two. - */ - return formatJavaCode(project, Regex("\n\n\n(\n)*").replace(testFullText, "\n\n"), testGenerationData) - } - - /** - * Returns the upper part of test suite (package name, imports, and test class name) as a string. - * - * @return the upper part of test suite (package name, imports, and test class name) as a string. - */ - private fun printUpperPart( - className: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - ): String { - var testText = "" - - // Add package - if (packageString.isNotBlank()) { - testText += "package $packageString;\n" - } - - // add imports - imports.forEach { importedElement -> - testText += "$importedElement\n" - } - - testText += "\n" - - // add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith)\n" - } - // open the test class - testText += "public class $className {\n\n" - - // Add other presets (annotations, non-test functions) - if (otherInfo.isNotBlank()) { - testText += otherInfo - } - - return testText - } - - /** - * Finds the test method from a given class with the specified test case name. - * - * @param code The code of the class containing test methods. - * @return The test method as a string, including the "@Test" annotation. - */ - fun getTestMethodCodeFromClassWithTestCase(code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - val upperCutCode = "\t@Test" + code.split("@Test").last() - var methodStarted = false - var balanceOfBrackets = 0 - for (symbol in upperCutCode) { - result += symbol - if (symbol == '{') { - methodStarted = true - balanceOfBrackets++ - } - if (symbol == '}') { - balanceOfBrackets-- - } - if (methodStarted && balanceOfBrackets == 0) { - break - } - } - return result + "\n" - } - } - - /** - * Retrieves the name of the test method from a given Java class with test cases. - * - * @param oldTestCaseName The old name of test case - * @param code The source code of the Java class with test cases. - * @return The name of the test method. If no test method is found, an empty string is returned. - */ - fun getTestMethodNameFromClassWithTestCase(oldTestCaseName: String, code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result = method.nameAsString - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - return oldTestCaseName - } - } - - /** - * Retrieves the class name from the given test case code. - * - * @param code The test case code to extract the class name from. - * @return The class name extracted from the test case code. - */ - fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - - /** - * Formats the given Java code using IntelliJ IDEA's code formatting rules. - * - * @param code The Java code to be formatted. - * @return The formatted Java code. - */ - fun formatJavaCode(project: Project, code: String, generatedTestData: TestGenerationData): String { - var result = "" - WriteCommandAction.runWriteCommandAction(project) { - val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" - // create a temporary PsiFile - val psiFile: PsiFile = PsiFileFactory.getInstance(project) - .createFileFromText( - fileName, - JavaLanguage.INSTANCE, - code, - ) - - CodeStyleManager.getInstance(project).reformat(psiFile) - - val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) - result = document?.text ?: code - - File(fileName).delete() - } - - return result - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index b36fe381a..d10525087 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -12,15 +12,19 @@ import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModif import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIPlatform import java.net.HttpURLConnection import javax.swing.DefaultComboBoxModel @@ -67,6 +71,9 @@ object LLMHelper { if (platformSelector.selectedItem!!.toString() == settingsState.grazieName) { models = getGrazieModels() } + if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) { + models = getHuggingFaceModels() + } modelSelector.model = DefaultComboBoxModel(models) for (index in llmPlatforms.indices) { if (llmPlatforms[index].name == settingsState.openAIName && @@ -81,6 +88,12 @@ object LLMHelper { modelSelector.selectedItem = settingsState.grazieModel llmPlatforms[index].model = modelSelector.selectedItem!!.toString() } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + modelSelector.selectedItem = settingsState.huggingFaceModel + llmPlatforms[index].model = modelSelector.selectedItem!!.toString() + } } modelSelector.isEnabled = true if (models.contentEquals(arrayOf(""))) modelSelector.isEnabled = false @@ -112,6 +125,12 @@ object LLMHelper { llmUserTokenField.text = settingsState.grazieToken llmPlatforms[index].token = settingsState.grazieToken } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + llmUserTokenField.text = settingsState.huggingFaceToken + llmPlatforms[index].token = settingsState.huggingFaceToken + } } } @@ -185,8 +204,6 @@ object LLMHelper { if (isGrazieClassLoaded()) { platformSelector.model = DefaultComboBoxModel(llmPlatforms.map { it.name }.toTypedArray()) platformSelector.selectedItem = settingsState.currentLLMPlatformName - } else { - platformSelector.isEnabled = false } llmUserTokenField.toolTipText = LLMSettingsBundle.get("llmToken") @@ -202,7 +219,7 @@ object LLMHelper { * @return The list of LLMPlatforms. */ fun getLLLMPlatforms(): List { - return listOf(OpenAIPlatform(), GraziePlatform()) + return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform()) } /** @@ -230,7 +247,7 @@ object LLMHelper { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun testModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -244,13 +261,28 @@ object LLMHelper { return null } + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + ) + + val testsAssembler = TestsAssemblerFactory.create( + indicator, + testGenerationOutput, + testSuiteParser, + jUnitVersion, + ) + val testSuite = executeTestCaseModificationRequest( language, testCase, task, indicator, requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, testGenerationOutput), + testsAssembler, errorMonitor, ) return testSuite @@ -328,4 +360,13 @@ object LLMHelper { arrayOf("") } } + + /** + * Retrieves the available HuggingFace models. + * + * @return an array of string representing the available HuggingFace models + */ + private fun getHuggingFaceModels(): Array { + return arrayOf("Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct") + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b20891ed4 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeAnalyzer.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.helpers + +/** + * Interface for retrieving information from test class code. + */ +interface TestClassCodeAnalyzer { + /** + * Extracts the code of the first test method found in the given class code. + * + * @param classCode The code of the class containing test methods. + * @return The code of the first test method as a string, including the "@Test" annotation. + */ + fun extractFirstTestMethodCode(classCode: String): String + + /** + * Retrieves the name of the first test method found in the given class code. + * + * @param oldTestCaseName The old name of a test case + * @param classCode The source code of the class containing test methods. + * @return The name of the first test method. If no test method is found, an empty string is returned. + */ + fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String + + /** + * Retrieves the class name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getClassFromTestCaseCode(code: String): String + + /** + * Return the right file name from the given test case code. + * + * @param code the test case code to extract the class name from + * @return the class name extracted from the test case code + */ + fun getFileNameFromTestCaseCode(code: String): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt new file mode 100644 index 000000000..7443b1664 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/TestClassCodeGenerator.kt @@ -0,0 +1,43 @@ +package org.jetbrains.research.testspark.helpers + +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.core.data.TestGenerationData + +/** + * Interface for generating and formatting test class code. + */ +interface TestClassCodeGenerator { + /** + * Generates the code for a test class. + * + * @param project the current project + * @param className the name of the test class + * @param body the body of the test class + * @param imports the set of imports needed in the test class + * @param packageString the package declaration of the test class + * @param runWith the runWith annotation for the test class + * @param otherInfo any other additional information for the test class + * @param testGenerationData the data used for test generation + * @return the generated code as a string + */ + fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String + + /** + * Formats the given Java code using IntelliJ IDEA's code formatting rules. + * + * @param project the current project + * @param code the Java code to be formatted + * @param generatedTestData the data used for generating the test + * @return the formatted Java code + */ + fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..f6f2fd0a9 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeAnalyzer.kt @@ -0,0 +1,78 @@ +package org.jetbrains.research.testspark.helpers.java + +import com.github.javaparser.ParseProblemException +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.visitor.VoidVisitorAdapter +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object JavaTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + val upperCutCode = "\t@Test" + classCode.split("@Test").last() + var methodStarted = false + var balanceOfBrackets = 0 + for (symbol in upperCutCode) { + result += symbol + if (symbol == '{') { + methodStarted = true + balanceOfBrackets++ + } + if (symbol == '}') { + balanceOfBrackets-- + } + if (methodStarted && balanceOfBrackets == 0) { + break + } + } + return result + "\n" + } + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + var result = "" + try { + val componentUnit: CompilationUnit = StaticJavaParser.parse(classCode) + + object : VoidVisitorAdapter() { + override fun visit(method: MethodDeclaration, arg: Any?) { + super.visit(method, arg) + if (method.getAnnotationByName("Test").isPresent) { + result = method.nameAsString + } + } + }.visit(componentUnit, null) + + return result + } catch (e: ParseProblemException) { + return oldTestCaseName + } + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.java" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt new file mode 100644 index 000000000..46c071d5f --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/java/JavaTestClassCodeGenerator.kt @@ -0,0 +1,104 @@ +package org.jetbrains.research.testspark.helpers.java + +import com.intellij.lang.java.JavaLanguage +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.project.Project +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiFileFactory +import com.intellij.psi.codeStyle.CodeStyleManager +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import java.io.File + +object JavaTestClassCodeGenerator : TestClassCodeGenerator { + + private val log = Logger.getInstance(this::class.java) + + override fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String { + var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) + + // Add each test (exclude expected exception) + testFullText += body + + // close the test class + testFullText += "}" + + testFullText.replace("\r\n", "\n") + + /** + * for better readability and make the tests shorter, we reduce the number of line breaks: + * when we have three or more sequential \n, reduce it to two. + */ + return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) + } + + override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { + var result = "" + WriteCommandAction.runWriteCommandAction(project) { + val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" + // create a temporary PsiFile + val psiFile: PsiFile = PsiFileFactory.getInstance(project) + .createFileFromText( + fileName, + JavaLanguage.INSTANCE, + code, + ) + + CodeStyleManager.getInstance(project).reformat(psiFile) + + val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) + result = document?.text ?: code + + File(fileName).delete() + } + + return result + } + + private fun printUpperPart( + className: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + ): String { + var testText = "" + + // Add package + if (packageString.isNotBlank()) { + testText += "package $packageString;\n" + } + + // add imports + imports.forEach { importedElement -> + testText += "$importedElement\n" + } + + testText += "\n" + + // add runWith if exists + if (runWith.isNotBlank()) { + testText += "@RunWith($runWith)\n" + } + // open the test class + testText += "public class $className {\n\n" + + // Add other presets (annotations, non-test functions) + if (otherInfo.isNotBlank()) { + testText += otherInfo + } + + return testText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt new file mode 100644 index 000000000..b21a97dfd --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeAnalyzer.kt @@ -0,0 +1,65 @@ +package org.jetbrains.research.testspark.helpers.kotlin + +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer + +object KotlinTestClassCodeAnalyzer : TestClassCodeAnalyzer { + + override fun extractFirstTestMethodCode(classCode: String): String { + val testMethods = StringBuilder() + val lines = classCode.lines() + + var methodStarted = false + var balanceOfBrackets = 0 + + for (line in lines) { + if (!methodStarted && line.contains("@Test")) { + methodStarted = true + testMethods.append(line).append("\n") + } else if (methodStarted) { + testMethods.append(line).append("\n") + for (char in line) { + if (char == '{') { + balanceOfBrackets++ + } else if (char == '}') { + balanceOfBrackets-- + } + } + if (balanceOfBrackets == 0) { + methodStarted = false + testMethods.append("\n") + } + } + } + + return testMethods.toString() + } + + override fun extractFirstTestMethodName(oldTestCaseName: String, classCode: String): String { + val lines = classCode.lines() + var testMethodName = oldTestCaseName + + for (line in lines) { + if (line.contains("@Test")) { + val methodDeclarationLine = lines[lines.indexOf(line) + 1] + val matchResult = Regex("fun\\s+(\\w+)\\s*\\(").find(methodDeclarationLine) + if (matchResult != null) { + testMethodName = matchResult.groupValues[1] + } + break + } + } + return testMethodName + } + + override fun getClassFromTestCaseCode(code: String): String { + val pattern = Regex("class\\s+(\\S+)\\s*\\{") + val matchResult = pattern.find(code) + matchResult ?: return "GeneratedTest" + val (className) = matchResult.destructured + return className + } + + override fun getFileNameFromTestCaseCode(code: String): String { + return "${getClassFromTestCaseCode(code)}.kt" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt new file mode 100644 index 000000000..eb10a7aa9 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/kotlin/KotlinTestClassCodeGenerator.kt @@ -0,0 +1,101 @@ +package org.jetbrains.research.testspark.helpers.kotlin + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.project.Project +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiFileFactory +import com.intellij.psi.codeStyle.CodeStyleManager +import org.jetbrains.kotlin.idea.KotlinLanguage +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import java.io.File + +object KotlinTestClassCodeGenerator : TestClassCodeGenerator { + + private val log = Logger.getInstance(this::class.java) + + override fun generateCode( + project: Project, + className: String, + body: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + testGenerationData: TestGenerationData, + ): String { + log.debug("[KotlinClassBuilderHelper] Generate code for $className") + + var testFullText = + printUpperPart(className, imports, packageString, runWith, otherInfo) + + // Add each test (exclude expected exception) + testFullText += body + + // close the test class + testFullText += "}" + + testFullText.replace("\r\n", "\n") + + // Reduce the number of line breaks for better readability + return formatCode(project, Regex("\n\n\n(?:\n)*").replace(testFullText, "\n\n"), testGenerationData) + } + + override fun formatCode(project: Project, code: String, generatedTestData: TestGenerationData): String { + var result = "" + WriteCommandAction.runWriteCommandAction(project) { + val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.kt" + // Create a temporary PsiFile + val psiFile: PsiFile = PsiFileFactory.getInstance(project) + .createFileFromText(fileName, KotlinLanguage.INSTANCE, code) + + CodeStyleManager.getInstance(project).reformat(psiFile) + + val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) + result = document?.text ?: code + + File(fileName).delete() + } + log.info("Formatted result class: $result") + return result + } + + private fun printUpperPart( + className: String, + imports: Set, + packageString: String, + runWith: String, + otherInfo: String, + ): String { + var testText = "" + + // Add package + if (packageString.isNotBlank()) { + testText += "package $packageString\n" + } + + // Add imports + imports.forEach { importedElement -> + testText += "$importedElement\n" + } + + testText += "\n" + + // Add runWith if exists + if (runWith.isNotBlank()) { + testText += "@RunWith($runWith::class)\n" + } + + // Open the test class + testText += "class $className {\n\n" + + // Add other presets (annotations, non-test functions) + if (otherInfo.isNotBlank()) { + testText += otherInfo + } + + return testText + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt index e3b11555a..6b257f421 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt @@ -1,425 +1,69 @@ package org.jetbrains.research.testspark.services -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.components.Service -import com.intellij.openapi.components.service -import com.intellij.openapi.fileChooser.FileChooser -import com.intellij.openapi.fileChooser.FileChooserDescriptor -import com.intellij.openapi.fileEditor.FileDocumentManager -import com.intellij.openapi.fileEditor.FileEditorManager -import com.intellij.openapi.fileEditor.OpenFileDescriptor -import com.intellij.openapi.fileEditor.TextEditor -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.LocalFileSystem -import com.intellij.openapi.vfs.VirtualFile -import com.intellij.openapi.vfs.VirtualFileManager -import com.intellij.openapi.wm.ToolWindowManager -import com.intellij.psi.PsiClass -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiElementFactory -import com.intellij.psi.PsiJavaFile -import com.intellij.psi.PsiManager -import com.intellij.refactoring.suggested.startOffset +import com.intellij.psi.PsiFile import com.intellij.ui.EditorTextField -import com.intellij.ui.JBColor -import com.intellij.ui.components.JBScrollPane -import com.intellij.ui.content.Content -import com.intellij.ui.content.ContentFactory -import com.intellij.ui.content.ContentManager -import com.intellij.util.containers.stream -import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle import org.jetbrains.research.testspark.core.data.Report -import org.jetbrains.research.testspark.core.data.TestCase -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.data.UIContext -import org.jetbrains.research.testspark.display.TestCasePanelFactory -import org.jetbrains.research.testspark.display.TopButtonsPanelFactory -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.ReportHelper -import java.awt.BorderLayout -import java.awt.Color -import java.awt.Dimension -import java.io.File -import java.util.Locale -import javax.swing.Box -import javax.swing.BoxLayout -import javax.swing.JButton -import javax.swing.JCheckBox -import javax.swing.JOptionPane +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import javax.swing.JPanel -import javax.swing.JSeparator -import javax.swing.SwingConstants -@Service(Service.Level.PROJECT) -class TestCaseDisplayService(private val project: Project) { - private var report: Report? = null - - private val unselectedTestCases = HashMap() - - private var mainPanel: JPanel = JPanel() - - private val topButtonsPanelFactory = TopButtonsPanelFactory(project) - - private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) - - private var allTestCasePanel: JPanel = JPanel() - - private var scrollPane: JBScrollPane = JBScrollPane( - allTestCasePanel, - JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, - ) - - private var testCasePanels: HashMap = HashMap() - - private var testsSelected: Int = 0 - - /** - * Default color for the editors in the tool window - */ - private var defaultEditorColor: Color? = null - - /** - * Content Manager to be able to add / remove tabs from tool window - */ - private var contentManager: ContentManager? = null - - /** - * Variable to keep reference to the coverage visualisation content - */ - private var content: Content? = null - - var uiContext: UIContext? = null - - init { - allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) - mainPanel.layout = BorderLayout() - - mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) - mainPanel.add(scrollPane, BorderLayout.CENTER) - - applyButton.isOpaque = false - applyButton.isContentAreaFilled = false - mainPanel.add(applyButton, BorderLayout.SOUTH) - - applyButton.addActionListener { applyTests() } - } +interface TestCaseDisplayService { /** * Fill the panel with the generated test cases. Remove all previously shown test cases. * Add Tests and their names to a List of pairs (used for highlighting) */ - fun displayTestCases(report: Report, uiContext: UIContext, language: Language) { - this.report = report - this.uiContext = uiContext - - val editor = project.service().editor!! - - allTestCasePanel.removeAll() - testCasePanels.clear() - - addSeparator() - - // TestCasePanelFactories array - val testCasePanelFactories = arrayListOf() - - report.testCaseList.values.forEach { - val testCase = it - val testCasePanel = JPanel() - testCasePanel.layout = BorderLayout() - - // Add a checkbox to select the test - val checkbox = JCheckBox() - checkbox.isSelected = true - checkbox.addItemListener { - // Update the number of selected tests - testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) - - if (checkbox.isSelected) { - ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) - } else { - ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) - } - - updateUI() - } - testCasePanel.add(checkbox, BorderLayout.WEST) - - val testCasePanelFactory = TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) - testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) - testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) - testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) - - testCasePanelFactories.add(testCasePanelFactory) - - testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) - - // Add panel to parent panel - testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) - allTestCasePanel.add(testCasePanel) - addSeparator() - testCasePanels[testCase.testName] = testCasePanel - } - - // Update the number of selected tests (all tests are selected by default) - testsSelected = testCasePanels.size - - topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) - topButtonsPanelFactory.updateTopLabels() - - createToolWindowTab() - } + fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) /** * Adds a separator to the allTestCasePanel. */ - private fun addSeparator() { - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - } + fun addSeparator() /** * Highlight the mini-editor in the tool window whose name corresponds with the name of the test provided * * @param name name of the test whose editor should be highlighted */ - fun highlightTestCase(name: String) { - val myPanel = testCasePanels[name] ?: return - openToolWindowTab() - scrollToPanel(myPanel) - - val editor = getEditor(name) ?: return - val settingsProjectState = project.service().state - val highlightColor = - JBColor( - PluginSettingsBundle.get("colorName"), - Color( - settingsProjectState.colorRed, - settingsProjectState.colorGreen, - settingsProjectState.colorBlue, - 30, - ), - ) - if (editor.background.equals(highlightColor)) return - defaultEditorColor = editor.background - editor.background = highlightColor - returnOriginalEditorBackground(editor) - } + fun highlightTestCase(name: String) /** * Method to open the toolwindow tab with generated tests if not already open. */ - private fun openToolWindowTab() { - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - toolWindowManager.show() - toolWindowManager.contentManager.setSelectedContent(content!!) - } - } + fun openToolWindowTab() /** * Scrolls to the highlighted panel. * * @param myPanel the panel to scroll to */ - private fun scrollToPanel(myPanel: JPanel) { - var sum = 0 - for (component in allTestCasePanel.components) { - if (component == myPanel) { - break - } else { - sum += component.height - } - } - val scroll = scrollPane.verticalScrollBar - scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height - } + fun scrollToPanel(myPanel: JPanel) /** * Removes all coverage highlighting from the editor. */ - private fun removeAllHighlights() { - project.service().editor?.markupModel?.removeAllHighlighters() - } + fun removeAllHighlights() /** * Reset the provided editors color to the default (initial) one after 10 seconds * @param editor the editor whose color to change */ - private fun returnOriginalEditorBackground(editor: EditorTextField) { - Thread { - Thread.sleep(10000) - editor.background = defaultEditorColor - }.start() - } + fun returnOriginalEditorBackground(editor: EditorTextField) /** * Highlight a range of editors * @param names list of test names to pass to highlight function */ - fun highlightCoveredMutants(names: List) { - names.forEach { - highlightTestCase(it) - } - } + fun highlightCoveredMutants(names: List) /** * Show a dialog where the user can select what test class the tests should be applied to, * and apply the selected tests to the test class. */ - private fun applyTests() { - // Filter the selected test cases - val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } - val selectedTestCases = selectedTestCasePanels.map { it.key } - - // Get the test case components (source code of the tests) - val testCaseComponents = selectedTestCases - .map { getEditor(it)!! } - .map { it.document.text } - - // Descriptor for choosing folders and java files - val descriptor = FileChooserDescriptor(true, true, false, false, false, false) - - // Apply filter with folders and java files with main class - WriteCommandAction.runWriteCommandAction(project) { - descriptor.withFileFilter { file -> - file.isDirectory || ( - file.extension?.lowercase(Locale.getDefault()) == "java" && ( - PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile - ).classes.stream().map { it.name } - .toArray() - .contains( - ( - PsiManager.getInstance(project) - .findFile(file) as PsiJavaFile - ).name.removeSuffix(".java"), - ) - ) - } - } - - val fileChooser = FileChooser.chooseFiles( - descriptor, - project, - LocalFileSystem.getInstance().findFileByPath(project.basePath!!), - ) - - /** - * Cancel button pressed - */ - if (fileChooser.isEmpty()) return - - /** - * Chosen files by user - */ - val chosenFile = fileChooser[0] - - /** - * Virtual file of a final java file - */ - var virtualFile: VirtualFile? = null - - /** - * PsiClass of a final java file - */ - var psiClass: PsiClass? = null - - /** - * PsiJavaFile of a final java file - */ - var psiJavaFile: PsiJavaFile? = null - - if (chosenFile.isDirectory) { - // Input new file data - var className: String - var fileName: String - var filePath: String - // Waiting for correct file name input - while (true) { - val jOptionPane = - JOptionPane.showInputDialog( - null, - PluginLabelsBundle.get("optionPaneMessage"), - PluginLabelsBundle.get("optionPaneTitle"), - JOptionPane.PLAIN_MESSAGE, - null, - null, - null, - ) - - // Cancel button pressed - jOptionPane ?: return - - // Get class name from user - className = jOptionPane as String - - // Set file name and file path - fileName = "${className.split('.')[0]}.java" - filePath = "${chosenFile.path}/$fileName" - - // Check the correctness of a class name - if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { - showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) - continue - } - - // Check the existence of a file with this name - if (File(filePath).exists()) { - showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) - continue - } - break - } - - // Create new file and set services of this file - WriteCommandAction.runWriteCommandAction(project) { - chosenFile.createChildData(null, fileName) - virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + fun applyTests() - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { - psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") - } - - psiJavaFile!!.add(psiClass!!) - } - } else { - // Set services of the chosen file - virtualFile = chosenFile - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = psiJavaFile!!.classes[ - psiJavaFile!!.classes.stream().map { it.name }.toArray() - .indexOf(psiJavaFile!!.name.removeSuffix(".java")), - ] - } - - // Add tests to the file - WriteCommandAction.runWriteCommandAction(project) { - appendTestsToClass(testCaseComponents, psiClass!!, psiJavaFile!!) - } - - // Remove the selected test cases from the cache and the tool window UI - removeSelectedTestCases(selectedTestCasePanels) - - // Open the file after adding - FileEditorManager.getInstance(project).openTextEditor( - OpenFileDescriptor(project, virtualFile!!), - true, - ) - } - - private fun showErrorWindow(message: String) { - JOptionPane.showMessageDialog( - null, - message, - PluginLabelsBundle.get("errorWindowTitle"), - JOptionPane.ERROR_MESSAGE, - ) - } + fun showErrorWindow(message: String) /** * Retrieve the editor corresponding to a particular test case @@ -427,11 +71,7 @@ class TestCaseDisplayService(private val project: Project) { * @param testCaseName the name of the test case * @return the editor corresponding to the test case, or null if it does not exist */ - fun getEditor(testCaseName: String): EditorTextField? { - val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null - val middlePanel = middlePanelComponent as JPanel - return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField - } + fun getEditor(testCaseName: String): EditorTextField? /** * Append the provided test cases to the provided class. @@ -440,107 +80,23 @@ class TestCaseDisplayService(private val project: Project) { * @param selectedClass the class which the test cases should be appended to * @param outputFile the output file for tests */ - private fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClass, outputFile: PsiJavaFile) { - // block document - PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!, - ) - - // insert tests to a code - testCaseComponents.reversed().forEach { - val testMethodCode = - JavaClassBuilderHelper.getTestMethodCodeFromClassWithTestCase( - JavaClassBuilderHelper.formatJavaCode( - project, - it.replace("\r\n", "\n") - .replace("verifyException(", "// verifyException("), - uiContext!!.testGenerationOutput, - ), - ) - // Fix Windows line separators - .replace("\r\n", "\n") - - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - testMethodCode, - ) - } - - // insert other info to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - uiContext!!.testGenerationOutput.otherInfo + "\n", - ) - - // insert imports to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, - uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", - ) - - // insert package to a code - outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! - .insertString( - 0, - if (uiContext!!.testGenerationOutput.packageLine.isEmpty()) { - "" - } else { - "package ${uiContext!!.testGenerationOutput.packageLine};\n\n" - }, - ) - } + fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClassWrapper, outputFile: PsiFile) /** * Utility function that returns the editor for a specific file url, * in case it is opened in the IDE */ - fun updateEditorForFileUrl(fileUrl: String) { - val documentManager = FileDocumentManager.getInstance() - // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 - FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { - val currentFile = documentManager.getFile(it.document) - if (currentFile != null) { - if (currentFile.presentableUrl == fileUrl) { - project.service().editor = it - } - } - } - } + fun updateEditorForFileUrl(fileUrl: String) /** * Creates a new toolWindow tab for the coverage visualisation. */ - private fun createToolWindowTab() { - // Remove generated tests tab from content manager if necessary - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - contentManager!!.removeContent(content!!, true) - } - - // If there is no generated tests tab, make it - val contentFactory: ContentFactory = ContentFactory.getInstance() - content = contentFactory.createContent( - mainPanel, - PluginLabelsBundle.get("generatedTests"), - true, - ) - contentManager!!.addContent(content!!) - - // Focus on generated tests tab and open toolWindow if not opened already - contentManager!!.setSelectedContent(content!!) - toolWindowManager.show() - } + fun createToolWindowTab() /** * Closes the tool window and destroys the content of the tab. */ - private fun closeToolWindow() { - contentManager?.removeContent(content!!, true) - ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() - val coverageVisualisationService = project.service() - coverageVisualisationService.closeToolWindowTab() - } + fun closeToolWindow() /** * Removes the selected tests from the cache, removes all the highlights from the editor and closes the tool window. @@ -549,37 +105,16 @@ class TestCaseDisplayService(private val project: Project) { * * @param selectedTestCasePanels the panels of the selected tests */ - private fun removeSelectedTestCases(selectedTestCasePanels: Map) { - selectedTestCasePanels.forEach { removeTestCase(it.key) } - removeAllHighlights() - closeToolWindow() - } - - fun clear() { - // Remove the tests - val testCasePanelsToRemove = testCasePanels.toMap() - removeSelectedTestCases(testCasePanelsToRemove) + fun removeSelectedTestCases(selectedTestCasePanels: Map) - topButtonsPanelFactory.clear() - } + fun clear() /** * A helper method to remove a test case from the cache and from the UI. * * @param testCaseName the name of the test */ - fun removeTestCase(testCaseName: String) { - // Update the number of selected test cases if necessary - if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { - testsSelected-- - } - - // Remove the test panel from the UI - allTestCasePanel.remove(testCasePanels[testCaseName]) - - // Remove the test panel - testCasePanels.remove(testCaseName) - } + fun removeTestCase(testCaseName: String) /** * Updates the user interface of the tool window. @@ -589,36 +124,26 @@ class TestCaseDisplayService(private val project: Project) { * of the topButtonsPanel object. It also checks if there are no more tests remaining * and closes the tool window if that is the case. */ - fun updateUI() { - // Update the UI of the tool window tab - allTestCasePanel.updateUI() - - topButtonsPanelFactory.updateTopLabels() - - // If no more tests are remaining, close the tool window - if (testCasePanels.size == 0) closeToolWindow() - } + fun updateUI() /** * Retrieves the list of test case panels. * * @return The list of test case panels. */ - fun getTestCasePanels() = testCasePanels + fun getTestCasePanels(): HashMap /** * Retrieves the currently selected tests. * * @return The list of tests currently selected. */ - fun getTestsSelected() = testsSelected + fun getTestsSelected(): Int /** * Sets the number of tests selected. * * @param testsSelected The number of tests selected. */ - fun setTestsSelected(testsSelected: Int) { - this.testsSelected = testsSelected - } + fun setTestsSelected(testsSelected: Int) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt new file mode 100644 index 000000000..0dbc5009c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt @@ -0,0 +1,544 @@ +package org.jetbrains.research.testspark.services.java + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.fileChooser.FileChooser +import com.intellij.openapi.fileChooser.FileChooserDescriptor +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.fileEditor.FileEditorManager +import com.intellij.openapi.fileEditor.OpenFileDescriptor +import com.intellij.openapi.fileEditor.TextEditor +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.VirtualFileManager +import com.intellij.openapi.wm.ToolWindowManager +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiElementFactory +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiManager +import com.intellij.refactoring.suggested.startOffset +import com.intellij.ui.EditorTextField +import com.intellij.ui.JBColor +import com.intellij.ui.components.JBScrollPane +import com.intellij.ui.content.Content +import com.intellij.ui.content.ContentFactory +import com.intellij.ui.content.ContentManager +import com.intellij.util.containers.stream +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle +import org.jetbrains.research.testspark.core.data.Report +import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.data.UIContext +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.display.TopButtonsPanelFactory +import org.jetbrains.research.testspark.helpers.ReportHelper +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator +import org.jetbrains.research.testspark.java.JavaPsiClassWrapper +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.CoverageVisualisationService +import org.jetbrains.research.testspark.services.EditorService +import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.services.TestCaseDisplayService +import java.awt.BorderLayout +import java.awt.Color +import java.awt.Dimension +import java.io.File +import java.util.Locale +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JOptionPane +import javax.swing.JPanel +import javax.swing.JSeparator +import javax.swing.SwingConstants + +@Service(Service.Level.PROJECT) +class JavaTestCaseDisplayService(private val project: Project) : TestCaseDisplayService { + private var report: Report? = null + + private val unselectedTestCases = HashMap() + + private var mainPanel: JPanel = JPanel() + + private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Java) + + private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) + + private var allTestCasePanel: JPanel = JPanel() + + private var scrollPane: JBScrollPane = JBScrollPane( + allTestCasePanel, + JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, + ) + + private var testCasePanels: HashMap = HashMap() + + private var testsSelected: Int = 0 + + /** + * Default color for the editors in the tool window + */ + private var defaultEditorColor: Color? = null + + /** + * Content Manager to be able to add / remove tabs from tool window + */ + private var contentManager: ContentManager? = null + + /** + * Variable to keep reference to the coverage visualisation content + */ + private var content: Content? = null + + var uiContext: UIContext? = null + + init { + allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) + mainPanel.layout = BorderLayout() + + mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) + mainPanel.add(scrollPane, BorderLayout.CENTER) + + applyButton.isOpaque = false + applyButton.isContentAreaFilled = false + mainPanel.add(applyButton, BorderLayout.SOUTH) + + applyButton.addActionListener { applyTests() } + } + + override fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) { + this.report = report + this.uiContext = uiContext + + val editor = project.service().editor!! + + allTestCasePanel.removeAll() + testCasePanels.clear() + + addSeparator() + + // TestCasePanelFactories array + val testCasePanelFactories = arrayListOf() + + report.testCaseList.values.forEach { + val testCase = it + val testCasePanel = JPanel() + testCasePanel.layout = BorderLayout() + + // Add a checkbox to select the test + val checkbox = JCheckBox() + checkbox.isSelected = true + checkbox.addItemListener { + // Update the number of selected tests + testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) + + if (checkbox.isSelected) { + ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) + } else { + ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) + } + + updateUI() + } + testCasePanel.add(checkbox, BorderLayout.WEST) + + val testCasePanelFactory = + TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) + testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) + testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) + testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) + + testCasePanelFactories.add(testCasePanelFactory) + + testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) + + // Add panel to parent panel + testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) + allTestCasePanel.add(testCasePanel) + addSeparator() + testCasePanels[testCase.testName] = testCasePanel + } + + // Update the number of selected tests (all tests are selected by default) + testsSelected = testCasePanels.size + + topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) + topButtonsPanelFactory.updateTopLabels() + + createToolWindowTab() + } + + override fun addSeparator() { + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + } + + override fun highlightTestCase(name: String) { + val myPanel = testCasePanels[name] ?: return + openToolWindowTab() + scrollToPanel(myPanel) + + val editor = getEditor(name) ?: return + val settingsProjectState = project.service().state + val highlightColor = + JBColor( + PluginSettingsBundle.get("colorName"), + Color( + settingsProjectState.colorRed, + settingsProjectState.colorGreen, + settingsProjectState.colorBlue, + 30, + ), + ) + if (editor.background.equals(highlightColor)) return + defaultEditorColor = editor.background + editor.background = highlightColor + returnOriginalEditorBackground(editor) + } + + override fun openToolWindowTab() { + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + toolWindowManager.show() + toolWindowManager.contentManager.setSelectedContent(content!!) + } + } + + override fun scrollToPanel(myPanel: JPanel) { + var sum = 0 + for (component in allTestCasePanel.components) { + if (component == myPanel) { + break + } else { + sum += component.height + } + } + val scroll = scrollPane.verticalScrollBar + scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height + } + + override fun removeAllHighlights() { + project.service().editor?.markupModel?.removeAllHighlighters() + } + + override fun returnOriginalEditorBackground(editor: EditorTextField) { + Thread { + Thread.sleep(10000) + editor.background = defaultEditorColor + }.start() + } + + override fun highlightCoveredMutants(names: List) { + names.forEach { + highlightTestCase(it) + } + } + + override fun applyTests() { + // Filter the selected test cases + val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } + val selectedTestCases = selectedTestCasePanels.map { it.key } + + // Get the test case components (source code of the tests) + val testCaseComponents = selectedTestCases + .map { getEditor(it)!! } + .map { it.document.text } + + // Descriptor for choosing folders and java files + val descriptor = FileChooserDescriptor(true, true, false, false, false, false) + + // Apply filter with folders and java files with main class + WriteCommandAction.runWriteCommandAction(project) { + descriptor.withFileFilter { file -> + file.isDirectory || ( + file.extension?.lowercase(Locale.getDefault()) == "java" && ( + PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile + ).classes.stream().map { it.name } + .toArray() + .contains( + ( + PsiManager.getInstance(project) + .findFile(file) as PsiJavaFile + ).name.removeSuffix(".java"), + ) + ) + } + } + + val fileChooser = FileChooser.chooseFiles( + descriptor, + project, + LocalFileSystem.getInstance().findFileByPath(project.basePath!!), + ) + + /** + * Cancel button pressed + */ + if (fileChooser.isEmpty()) return + + /** + * Chosen files by user + */ + val chosenFile = fileChooser[0] + + /** + * Virtual file of a final java file + */ + var virtualFile: VirtualFile? = null + + /** + * PsiClass of a final java file + */ + var psiClass: PsiClass? = null + + /** + * PsiJavaFile of a final java file + */ + var psiJavaFile: PsiJavaFile? = null + + if (chosenFile.isDirectory) { + // Input new file data + var className: String + var fileName: String + var filePath: String + // Waiting for correct file name input + while (true) { + val jOptionPane = + JOptionPane.showInputDialog( + null, + PluginLabelsBundle.get("optionPaneMessage"), + PluginLabelsBundle.get("optionPaneTitle"), + JOptionPane.PLAIN_MESSAGE, + null, + null, + null, + ) + + // Cancel button pressed + jOptionPane ?: return + + // Get class name from user + className = jOptionPane as String + + // Set file name and file path + fileName = "${className.split('.')[0]}.java" + filePath = "${chosenFile.path}/$fileName" + + // Check the correctness of a class name + if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { + showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) + continue + } + + // Check the existence of a file with this name + if (File(filePath).exists()) { + showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) + continue + } + break + } + + // Create new file and set services of this file + WriteCommandAction.runWriteCommandAction(project) { + chosenFile.createChildData(null, fileName) + virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! + psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) + psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + + if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { + psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") + } + + psiJavaFile!!.add(psiClass!!) + } + } else { + // Set services of the chosen file + virtualFile = chosenFile + psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) + psiClass = psiJavaFile!!.classes[ + psiJavaFile!!.classes.stream().map { it.name }.toArray() + .indexOf(psiJavaFile!!.name.removeSuffix(".java")), + ] + } + + // Add tests to the file + WriteCommandAction.runWriteCommandAction(project) { + appendTestsToClass(testCaseComponents, JavaPsiClassWrapper(psiClass!!), psiJavaFile!!) + } + + // Remove the selected test cases from the cache and the tool window UI + removeSelectedTestCases(selectedTestCasePanels) + + // Open the file after adding + FileEditorManager.getInstance(project).openTextEditor( + OpenFileDescriptor(project, virtualFile!!), + true, + ) + } + + override fun showErrorWindow(message: String) { + JOptionPane.showMessageDialog( + null, + message, + PluginLabelsBundle.get("errorWindowTitle"), + JOptionPane.ERROR_MESSAGE, + ) + } + + override fun getEditor(testCaseName: String): EditorTextField? { + val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null + val middlePanel = middlePanelComponent as JPanel + return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField + } + + override fun appendTestsToClass( + testCaseComponents: List, + selectedClass: PsiClassWrapper, + outputFile: PsiFile, + ) { + // block document + PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( + PsiDocumentManager.getInstance(project).getDocument(outputFile as PsiJavaFile)!!, + ) + + // insert tests to a code + testCaseComponents.reversed().forEach { + val testMethodCode = + JavaTestClassCodeAnalyzer.extractFirstTestMethodCode( + JavaTestClassCodeGenerator.formatCode( + project, + it.replace("\r\n", "\n") + .replace("verifyException(", "// verifyException("), + uiContext!!.testGenerationOutput, + ), + ) + // Fix Windows line separators + .replace("\r\n", "\n") + + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + testMethodCode, + ) + } + + // insert other info to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + uiContext!!.testGenerationOutput.otherInfo + "\n", + ) + + // insert imports to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, + uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", + ) + + // insert package to a code + outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! + .insertString( + 0, + if (uiContext!!.testGenerationOutput.packageName.isEmpty()) { + "" + } else { + "package ${uiContext!!.testGenerationOutput.packageName};\n\n" + }, + ) + } + + override fun updateEditorForFileUrl(fileUrl: String) { + val documentManager = FileDocumentManager.getInstance() + // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 + FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { + val currentFile = documentManager.getFile(it.document) + if (currentFile != null) { + if (currentFile.presentableUrl == fileUrl) { + project.service().editor = it + } + } + } + } + + override fun createToolWindowTab() { + // Remove generated tests tab from content manager if necessary + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + contentManager!!.removeContent(content!!, true) + } + + // If there is no generated tests tab, make it + val contentFactory: ContentFactory = ContentFactory.getInstance() + content = contentFactory.createContent( + mainPanel, + PluginLabelsBundle.get("generatedTests"), + true, + ) + contentManager!!.addContent(content!!) + + // Focus on generated tests tab and open toolWindow if not opened already + contentManager!!.setSelectedContent(content!!) + toolWindowManager.show() + } + + override fun closeToolWindow() { + contentManager?.removeContent(content!!, true) + ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() + val coverageVisualisationService = project.service() + coverageVisualisationService.closeToolWindowTab() + } + + override fun removeSelectedTestCases(selectedTestCasePanels: Map) { + selectedTestCasePanels.forEach { removeTestCase(it.key) } + removeAllHighlights() + closeToolWindow() + } + + override fun clear() { + // Remove the tests + val testCasePanelsToRemove = testCasePanels.toMap() + removeSelectedTestCases(testCasePanelsToRemove) + + topButtonsPanelFactory.clear() + } + + override fun removeTestCase(testCaseName: String) { + // Update the number of selected test cases if necessary + if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { + testsSelected-- + } + + // Remove the test panel from the UI + allTestCasePanel.remove(testCasePanels[testCaseName]) + + // Remove the test panel + testCasePanels.remove(testCaseName) + } + + override fun updateUI() { + // Update the UI of the tool window tab + allTestCasePanel.updateUI() + + topButtonsPanelFactory.updateTopLabels() + + // If no more tests are remaining, close the tool window + if (testCasePanels.size == 0) closeToolWindow() + } + + override fun getTestCasePanels() = testCasePanels + + override fun getTestsSelected() = testsSelected + + override fun setTestsSelected(testsSelected: Int) { + this.testsSelected = testsSelected + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt new file mode 100644 index 000000000..a77edd16d --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt @@ -0,0 +1,554 @@ +package org.jetbrains.research.testspark.services.kotlin + +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.fileChooser.FileChooser +import com.intellij.openapi.fileChooser.FileChooserDescriptor +import com.intellij.openapi.fileEditor.FileDocumentManager +import com.intellij.openapi.fileEditor.FileEditorManager +import com.intellij.openapi.fileEditor.OpenFileDescriptor +import com.intellij.openapi.fileEditor.TextEditor +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.LocalFileSystem +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.VirtualFileManager +import com.intellij.openapi.wm.ToolWindowManager +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiManager +import com.intellij.refactoring.suggested.endOffset +import com.intellij.refactoring.suggested.startOffset +import com.intellij.ui.EditorTextField +import com.intellij.ui.JBColor +import com.intellij.ui.components.JBScrollPane +import com.intellij.ui.content.Content +import com.intellij.ui.content.ContentFactory +import com.intellij.ui.content.ContentManager +import com.intellij.util.containers.stream +import org.jetbrains.kotlin.psi.KtClass +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtPsiFactory +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle +import org.jetbrains.research.testspark.core.data.Report +import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.data.UIContext +import org.jetbrains.research.testspark.display.TestCasePanelFactory +import org.jetbrains.research.testspark.display.TopButtonsPanelFactory +import org.jetbrains.research.testspark.helpers.ReportHelper +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator +import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.CoverageVisualisationService +import org.jetbrains.research.testspark.services.EditorService +import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.services.TestCaseDisplayService +import java.awt.BorderLayout +import java.awt.Color +import java.awt.Dimension +import java.io.File +import java.util.Locale +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JCheckBox +import javax.swing.JOptionPane +import javax.swing.JPanel +import javax.swing.JSeparator +import javax.swing.SwingConstants + +@Service(Service.Level.PROJECT) +class KotlinTestCaseDisplayService(private val project: Project) : TestCaseDisplayService { + private var report: Report? = null + + private val unselectedTestCases = HashMap() + + private var mainPanel: JPanel = JPanel() + + private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Kotlin) + + private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) + + private var allTestCasePanel: JPanel = JPanel() + + private var scrollPane: JBScrollPane = JBScrollPane( + allTestCasePanel, + JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, + ) + + private var testCasePanels: HashMap = HashMap() + + private var testsSelected: Int = 0 + + /** + * Default color for the editors in the tool window + */ + private var defaultEditorColor: Color? = null + + /** + * Content Manager to be able to add / remove tabs from tool window + */ + private var contentManager: ContentManager? = null + + /** + * Variable to keep reference to the coverage visualisation content + */ + private var content: Content? = null + + var uiContext: UIContext? = null + + init { + allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) + mainPanel.layout = BorderLayout() + + mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) + mainPanel.add(scrollPane, BorderLayout.CENTER) + + applyButton.isOpaque = false + applyButton.isContentAreaFilled = false + mainPanel.add(applyButton, BorderLayout.SOUTH) + + applyButton.addActionListener { applyTests() } + } + + override fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) { + this.report = report + this.uiContext = uiContext + + val editor = project.service().editor!! + + allTestCasePanel.removeAll() + testCasePanels.clear() + + addSeparator() + + // TestCasePanelFactories array + val testCasePanelFactories = arrayListOf() + + report.testCaseList.values.forEach { + val testCase = it + val testCasePanel = JPanel() + testCasePanel.layout = BorderLayout() + + // Add a checkbox to select the test + val checkbox = JCheckBox() + checkbox.isSelected = true + checkbox.addItemListener { + // Update the number of selected tests + testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) + + if (checkbox.isSelected) { + ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) + } else { + ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) + } + + updateUI() + } + testCasePanel.add(checkbox, BorderLayout.WEST) + + val testCasePanelFactory = + TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) + testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) + testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) + testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) + + testCasePanelFactories.add(testCasePanelFactory) + + testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) + + // Add panel to parent panel + testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) + allTestCasePanel.add(testCasePanel) + addSeparator() + testCasePanels[testCase.testName] = testCasePanel + } + + // Update the number of selected tests (all tests are selected by default) + testsSelected = testCasePanels.size + + topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) + topButtonsPanelFactory.updateTopLabels() + + createToolWindowTab() + } + + override fun addSeparator() { + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) + allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) + } + + override fun highlightTestCase(name: String) { + val myPanel = testCasePanels[name] ?: return + openToolWindowTab() + scrollToPanel(myPanel) + + val editor = getEditor(name) ?: return + val settingsProjectState = project.service().state + val highlightColor = + JBColor( + PluginSettingsBundle.get("colorName"), + Color( + settingsProjectState.colorRed, + settingsProjectState.colorGreen, + settingsProjectState.colorBlue, + 30, + ), + ) + if (editor.background.equals(highlightColor)) return + defaultEditorColor = editor.background + editor.background = highlightColor + returnOriginalEditorBackground(editor) + } + + override fun openToolWindowTab() { + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + toolWindowManager.show() + toolWindowManager.contentManager.setSelectedContent(content!!) + } + } + + override fun scrollToPanel(myPanel: JPanel) { + var sum = 0 + for (component in allTestCasePanel.components) { + if (component == myPanel) { + break + } else { + sum += component.height + } + } + val scroll = scrollPane.verticalScrollBar + scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height + } + + override fun removeAllHighlights() { + project.service().editor?.markupModel?.removeAllHighlighters() + } + + override fun returnOriginalEditorBackground(editor: EditorTextField) { + Thread { + Thread.sleep(10000) + editor.background = defaultEditorColor + }.start() + } + + override fun highlightCoveredMutants(names: List) { + names.forEach { + highlightTestCase(it) + } + } + + override fun applyTests() { + // Filter the selected test cases + val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } + val selectedTestCases = selectedTestCasePanels.map { it.key } + + // Get the test case components (source code of the tests) + val testCaseComponents = selectedTestCases + .map { getEditor(it)!! } + .map { it.document.text } + + // Descriptor for choosing folders and java files + val descriptor = FileChooserDescriptor(true, true, false, false, false, false) + + // Apply filter with folders and java files with main class + WriteCommandAction.runWriteCommandAction(project) { + descriptor.withFileFilter { file -> + file.isDirectory || ( + file.extension?.lowercase(Locale.getDefault()) == "kotlin" && ( + PsiManager.getInstance(project).findFile(file!!) as KtFile + ).classes.stream().map { it.name } + .toArray() + .contains( + ( + PsiManager.getInstance(project) + .findFile(file) as PsiJavaFile + ).name.removeSuffix(".kt"), + ) + ) + } + } + + val fileChooser = FileChooser.chooseFiles( + descriptor, + project, + LocalFileSystem.getInstance().findFileByPath(project.basePath!!), + ) + + /** + * Cancel button pressed + */ + if (fileChooser.isEmpty()) return + + /** + * Chosen files by user + */ + val chosenFile = fileChooser[0] + + /** + * Virtual file of a final java file + */ + var virtualFile: VirtualFile? = null + + /** + * PsiClass of a final java file + */ + var ktClass: KtClass? = null + + /** + * PsiJavaFile of a final java file + */ + var psiKotlinFile: KtFile? = null + + if (chosenFile.isDirectory) { + // Input new file data + var className: String + var fileName: String + var filePath: String + // Waiting for correct file name input + while (true) { + val jOptionPane = + JOptionPane.showInputDialog( + null, + PluginLabelsBundle.get("optionPaneMessage"), + PluginLabelsBundle.get("optionPaneTitle"), + JOptionPane.PLAIN_MESSAGE, + null, + null, + null, + ) + + // Cancel button pressed + jOptionPane ?: return + + // Get class name from user + className = jOptionPane as String + + // Set file name and file path + fileName = "${className.split('.')[0]}.kt" + filePath = "${chosenFile.path}/$fileName" + + // Check the correctness of a class name + if (!Regex("[A-Z][a-zA-Z0-9]*(.kt)?").matches(className)) { + showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) + continue + } + + // Check the existence of a file with this name + if (File(filePath).exists()) { + showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) + continue + } + break + } + + // Create new file and set services of this file + WriteCommandAction.runWriteCommandAction(project) { + chosenFile.createChildData(null, fileName) + virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! + psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile) + + val ktPsiFactory = KtPsiFactory(project) + ktClass = ktPsiFactory.createClass("class ${className.split(".")[0]} {}") + + if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { + val annotationEntry = + ktPsiFactory.createAnnotationEntry("@RunWith(${uiContext!!.testGenerationOutput.runWith})") + ktClass!!.addBefore(annotationEntry, ktClass!!.body) + } + + psiKotlinFile!!.add(ktClass!!) + } + } else { + // Set services of the chosen file + virtualFile = chosenFile + psiKotlinFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as KtFile) + val classNameNoSuffix = psiKotlinFile!!.name.removeSuffix(".kt") + ktClass = psiKotlinFile?.declarations?.filterIsInstance()?.find { it.name == classNameNoSuffix } + } + + // Add tests to the file + WriteCommandAction.runWriteCommandAction(project) { + appendTestsToClass(testCaseComponents, KotlinPsiClassWrapper(ktClass as KtClass), psiKotlinFile!!) + } + + // Remove the selected test cases from the cache and the tool window UI + removeSelectedTestCases(selectedTestCasePanels) + + // Open the file after adding + FileEditorManager.getInstance(project).openTextEditor( + OpenFileDescriptor(project, virtualFile!!), + true, + ) + } + + override fun showErrorWindow(message: String) { + JOptionPane.showMessageDialog( + null, + message, + PluginLabelsBundle.get("errorWindowTitle"), + JOptionPane.ERROR_MESSAGE, + ) + } + + override fun getEditor(testCaseName: String): EditorTextField? { + val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null + val middlePanel = middlePanelComponent as JPanel + return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField + } + + override fun appendTestsToClass( + testCaseComponents: List, + selectedClass: PsiClassWrapper, + outputFile: PsiFile, + ) { + // block document + PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( + PsiDocumentManager.getInstance(project).getDocument(outputFile as KtFile)!!, + ) + + // insert tests to a code + testCaseComponents.reversed().forEach { + val testMethodCode = + KotlinTestClassCodeAnalyzer.extractFirstTestMethodCode( + KotlinTestClassCodeGenerator.formatCode( + project, + it.replace("\r\n", "\n") + .replace("verifyException(", "// verifyException("), + uiContext!!.testGenerationOutput, + ), + ) + // Fix Windows line separators + .replace("\r\n", "\n") + + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + testMethodCode, + ) + } + + // insert other info to a code + PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( + selectedClass.rBrace!!, + uiContext!!.testGenerationOutput.otherInfo + "\n", + ) + + // Create the imports string + val importsString = uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n" + + // Find the insertion offset + val insertionOffset = outputFile.importList?.startOffset + ?: outputFile.packageDirective?.endOffset + ?: 0 + + // Insert the imports into the document + PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document -> + document.insertString(insertionOffset, importsString) + PsiDocumentManager.getInstance(project).commitDocument(document) + } + + val packageName = uiContext!!.testGenerationOutput.packageName + val packageStatement = if (packageName.isEmpty()) "" else "package $packageName\n\n" + + // Insert the package statement at the beginning of the document + PsiDocumentManager.getInstance(project).getDocument(outputFile)?.let { document -> + document.insertString(0, packageStatement) + PsiDocumentManager.getInstance(project).commitDocument(document) + } + } + + override fun updateEditorForFileUrl(fileUrl: String) { + val documentManager = FileDocumentManager.getInstance() + // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 + FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { + val currentFile = documentManager.getFile(it.document) + if (currentFile != null) { + if (currentFile.presentableUrl == fileUrl) { + project.service().editor = it + } + } + } + } + + override fun createToolWindowTab() { + // Remove generated tests tab from content manager if necessary + val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") + contentManager = toolWindowManager!!.contentManager + if (content != null) { + contentManager!!.removeContent(content!!, true) + } + + // If there is no generated tests tab, make it + val contentFactory: ContentFactory = ContentFactory.getInstance() + content = contentFactory.createContent( + mainPanel, + PluginLabelsBundle.get("generatedTests"), + true, + ) + contentManager!!.addContent(content!!) + + // Focus on generated tests tab and open toolWindow if not opened already + contentManager!!.setSelectedContent(content!!) + toolWindowManager.show() + } + + override fun closeToolWindow() { + contentManager?.removeContent(content!!, true) + ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() + val coverageVisualisationService = project.service() + coverageVisualisationService.closeToolWindowTab() + } + + override fun removeSelectedTestCases(selectedTestCasePanels: Map) { + selectedTestCasePanels.forEach { removeTestCase(it.key) } + removeAllHighlights() + closeToolWindow() + } + + override fun clear() { + // Remove the tests + val testCasePanelsToRemove = testCasePanels.toMap() + removeSelectedTestCases(testCasePanelsToRemove) + + topButtonsPanelFactory.clear() + } + + override fun removeTestCase(testCaseName: String) { + // Update the number of selected test cases if necessary + if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { + testsSelected-- + } + + // Remove the test panel from the UI + allTestCasePanel.remove(testCasePanels[testCaseName]) + + // Remove the test panel + testCasePanels.remove(testCaseName) + } + + override fun updateUI() { + // Update the UI of the tool window tab + allTestCasePanel.updateUI() + + topButtonsPanelFactory.updateTopLabels() + + // If no more tests are remaining, close the tool window + if (testCasePanels.size == 0) closeToolWindow() + } + + override fun getTestCasePanels() = testCasePanels + + override fun getTestsSelected() = testsSelected + + override fun setTestsSelected(testsSelected: Int) { + this.testsSelected = testsSelected + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt index 89e480e83..6c3d77a05 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt @@ -45,7 +45,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { // Models private var modelSelector = ComboBox(arrayOf("")) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) // Default LLM Requests private var defaultLLMRequestsSeparator = diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt index 5f792b328..2b0ff5769 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt @@ -42,6 +42,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab settingsComponent!!.llmPlatforms[index].token = llmSettingsState.grazieToken settingsComponent!!.llmPlatforms[index].model = llmSettingsState.grazieModel } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + settingsComponent!!.llmPlatforms[index].token = llmSettingsState.huggingFaceToken + settingsComponent!!.llmPlatforms[index].model = llmSettingsState.huggingFaceModel + } } settingsComponent!!.currentLLMPlatformName = llmSettingsState.currentLLMPlatformName settingsComponent!!.maxLLMRequest = llmSettingsState.maxLLMRequest @@ -81,6 +85,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.grazieToken) modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.grazieModel) } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.huggingFaceToken) + modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.huggingFaceModel) + } } modified = modified or (settingsComponent!!.currentLLMPlatformName != llmSettingsState.currentLLMPlatformName) modified = modified or (settingsComponent!!.maxLLMRequest != llmSettingsState.maxLLMRequest) @@ -138,6 +146,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab llmSettingsState.grazieToken = settingsComponent!!.llmPlatforms[index].token llmSettingsState.grazieModel = settingsComponent!!.llmPlatforms[index].model } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = settingsComponent!!.llmPlatforms[index].token + llmSettingsState.huggingFaceModel = settingsComponent!!.llmPlatforms[index].model + } } llmSettingsState.currentLLMPlatformName = settingsComponent!!.currentLLMPlatformName llmSettingsState.maxLLMRequest = settingsComponent!!.maxLLMRequest diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt index 3ce378707..590ec3c1d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt @@ -15,6 +15,9 @@ data class LLMSettingsState( var grazieName: String = DefaultLLMSettingsState.grazieName, var grazieToken: String = DefaultLLMSettingsState.grazieToken, var grazieModel: String = DefaultLLMSettingsState.grazieModel, + var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName, + var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken, + var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel, var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName, var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest, var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth, @@ -45,6 +48,9 @@ data class LLMSettingsState( val grazieName: String = LLMDefaultsBundle.get("grazieName") val grazieToken: String = LLMDefaultsBundle.get("grazieToken") val grazieModel: String = LLMDefaultsBundle.get("grazieModel") + val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName") + val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken") + val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel") var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName") val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt() val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt index 0cd1b073a..c4310ba61 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt @@ -2,7 +2,7 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.application.PathManager import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.dependencies.JavaTestCompilationDependencies +import org.jetbrains.research.testspark.core.test.data.dependencies.TestCompilationDependencies import java.io.File /** @@ -16,7 +16,7 @@ class LibraryPathsProvider { private val sep = File.separatorChar private val libPrefix = "${PathManager.getPluginsPath()}${sep}TestSpark${sep}lib$sep" - fun getTestCompilationLibraryPaths() = JavaTestCompilationDependencies.getJarDescriptors().map { descriptor -> + fun getTestCompilationLibraryPaths() = TestCompilationDependencies.getJarDescriptors().map { descriptor -> "$libPrefix${sep}${descriptor.name}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index aa5b694b7..30ed0ba6b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.roots.ProjectRootManager import com.intellij.openapi.util.io.FileUtilRt import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext @@ -22,6 +22,8 @@ import org.jetbrains.research.testspark.services.CoverageVisualisationService import org.jetbrains.research.testspark.services.EditorService import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.util.UUID @@ -29,7 +31,7 @@ import java.util.UUID * Pipeline class represents a pipeline for generating tests in a project. * * @param project the project in which the pipeline is executed. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caretOffset the offset of the caret position in the PSI file. * @param fileUrl the URL of the file being processed, if applicable. * @param packageName the package name of the file being processed. @@ -47,7 +49,7 @@ class Pipeline( init { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! + val cutPsiClass = psiHelper.getSurroundingClass(caretOffset) // get generated test path val testResultDirectory = "${FileUtilRt.getTempDirectory()}${ToolUtils.sep}testSparkResults${ToolUtils.sep}" @@ -57,10 +59,8 @@ class Pipeline( ApplicationManager.getApplication().runWriteAction { projectContext.projectClassPath = ProjectRootManager.getInstance(project).contentRoots.first().path projectContext.fileUrlAsString = fileUrl - projectContext.classFQN = cutPsiClass.qualifiedName - // TODO probably can be made easier - projectContext.cutModule = - ProjectFileIndex.getInstance(project).getModuleForFile(cutPsiClass.virtualFile)!! + cutPsiClass?.let { projectContext.classFQN = it.qualifiedName } + projectContext.cutModule = psiHelper.getModuleFromPsiFile() } generatedTestsData.resultPath = ToolUtils.getResultPath(id, testResultDirectory) @@ -108,14 +108,13 @@ class Pipeline( override fun onFinished() { super.onFinished() testGenerationController.finished() - uiContext?.let { - project.service() - .updateEditorForFileUrl(it.testGenerationOutput.fileUrl) - - if (project.service().editor != null) { - val report = it.testGenerationOutput.testGenerationResultList[0]!! - project.service().displayTestCases(report, it, psiHelper.language) - project.service().showCoverage(report) + when (psiHelper.language) { + SupportedLanguage.Java -> uiContext?.let { + displayTestCase(it) + } + + SupportedLanguage.Kotlin -> uiContext?.let { + displayTestCase(it) } } } @@ -124,8 +123,22 @@ class Pipeline( private fun clear(project: Project) { // should be removed totally! testGenerationController.errorMonitor.clear() - project.service().clear() + when (psiHelper.language) { + SupportedLanguage.Java -> project.service().clear() + SupportedLanguage.Kotlin -> project.service().clear() + } + project.service().clear() project.service().clear() } + + private inline fun displayTestCase(ctx: UIContext) { + project.service().updateEditorForFileUrl(ctx.testGenerationOutput.fileUrl) + + if (project.service().editor != null) { + val report = ctx.testGenerationOutput.testGenerationResultList[0]!! + project.service().displayTestCases(report, ctx, psiHelper.language) + project.service().showCoverage(report) + } + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt new file mode 100644 index 000000000..ea0c0bc2e --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestBodyPrinterFactory.kt @@ -0,0 +1,17 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.java.JavaTestBodyPrinter +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter + +class TestBodyPrinterFactory { + companion object { + fun create(language: SupportedLanguage): TestBodyPrinter { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestBodyPrinter() + SupportedLanguage.Java -> JavaTestBodyPrinter() + } + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt new file mode 100644 index 000000000..1b73c380c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeAnalyzerFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer + +object TestClassCodeAnalyzerFactory { + /** + * Creates an instance of TestClassCodeAnalyzer for the specified language. + * + * @param language the programming language for which to create the analyzer + * @return an instance of TestClassCodeAnalyzer + */ + fun create(language: SupportedLanguage): TestClassCodeAnalyzer { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeAnalyzer + SupportedLanguage.Java -> JavaTestClassCodeAnalyzer + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt new file mode 100644 index 000000000..56151e26e --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestClassCodeGeneratorFactory.kt @@ -0,0 +1,21 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.helpers.TestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.java.JavaTestClassCodeGenerator +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator + +object TestClassCodeGeneratorFactory { + /** + * Creates an instance of TestClassCodeGenerator for the specified language. + * + * @param language the programming language for which to create the generator + * @return an instance of TestClassCodeGenerator + */ + fun create(language: SupportedLanguage): TestClassCodeGenerator { + return when (language) { + SupportedLanguage.Kotlin -> KotlinTestClassCodeGenerator + SupportedLanguage.Java -> JavaTestClassCodeGenerator + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt index 8680370bd..84b512bb5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt @@ -3,20 +3,31 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.java.JavaTestCompiler +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestCompiler class TestCompilerFactory { companion object { - fun createJavacTestCompiler( + fun create( project: Project, junitVersion: JUnitVersion, + language: SupportedLanguage, javaHomeDirectory: String? = null, ): TestCompiler { - val javaHomePath = javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + val javaSDKHomePath = + javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk?.homeDirectory?.path + ?: throw RuntimeException("Java SDK not configured for the project.") + val libraryPaths = LibraryPathsProvider.getTestCompilationLibraryPaths() val junitLibraryPaths = LibraryPathsProvider.getJUnitLibraryPaths(junitVersion) - return TestCompiler(javaHomePath, libraryPaths, junitLibraryPaths) + // TODO add the warning window that for Java we always need the javaHomeDirectoryPath + return when (language) { + SupportedLanguage.Java -> JavaTestCompiler(libraryPaths, junitLibraryPaths, javaSDKHomePath) + SupportedLanguage.Kotlin -> KotlinTestCompiler(libraryPaths, junitLibraryPaths) + } } } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt index e0a4150b4..d35589357 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt @@ -8,6 +8,7 @@ import com.intellij.openapi.roots.CompilerModuleExtension import com.intellij.openapi.roots.ModuleRootManager import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil @@ -25,16 +26,20 @@ class TestProcessor( val project: Project, givenProjectSDKPath: Path? = null, ) : TestsPersistentStorage { - private val javaHomeDirectory = givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + private val homeDirectory = + givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path private val log = Logger.getInstance(this::class.java) private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state - val testCompiler = TestCompilerFactory.createJavacTestCompiler(project, llmSettingsState.junitVersion, javaHomeDirectory) - - override fun saveGeneratedTest(packageString: String, code: String, resultPath: String, testFileName: String): String { + override fun saveGeneratedTest( + packageString: String, + code: String, + resultPath: String, + testFileName: String, + ): String { // Generate the final path for the generated tests var generatedTestPath = "$resultPath${File.separatorChar}" packageString.split(".").forEach { directory -> @@ -69,14 +74,10 @@ class TestProcessor( generatedTestPackage: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): String { // find the proper javac - val javaRunner = File(javaHomeDirectory).walk() - .filter { - val isJavaName = if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") - isJavaName && it.isFile - } - .first() + val javaRunner = findJavaCompilerInDirectory(homeDirectory) // JaCoCo libs val jacocoAgentLibraryPath = "\"${LibraryPathsProvider.getJacocoAgentLibraryPath()}\"" val jacocoCLILibraryPath = "\"${LibraryPathsProvider.getJacocoCliLibraryPath()}\"" @@ -90,13 +91,21 @@ class TestProcessor( val junitVersion = llmSettingsState.junitVersion.version // run the test method with jacoco agent + log.info("[TestProcessor] Executing $name") val junitRunnerLibraryPath = LibraryPathsProvider.getJUnitRunnerLibraryPath() + // classFQN will be null for the top level function + val javaAgentFlag = + if (projectContext.classFQN != null) { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}" + } else { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false" + } val testExecutionError = CommandLineRunner.run( arrayListOf( javaRunner.absolutePath, - "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}", + javaAgentFlag, "-cp", - "\"${testCompiler.getPath(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", + "\"${testCompiler.getClassPaths(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", "org.jetbrains.research.SingleJUnitTestRunner$junitVersion", name, ), @@ -148,9 +157,10 @@ class TestProcessor( testId: Int, testName: String, testCode: String, - packageLine: String, + packageName: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): TestCase { // get buildPath var buildPath: String = ProjectRootManager.getInstance(project).contentRoots.first().path @@ -161,7 +171,7 @@ class TestProcessor( // save new test to file val generatedTestPath: String = saveGeneratedTest( - packageLine, + packageName, testCode, resultPath, fileName, @@ -179,9 +189,10 @@ class TestProcessor( dataFileName, testName, buildPath, - packageLine, + packageName, resultPath, projectContext, + testCompiler, ) if (!File("$dataFileName.xml").exists()) { @@ -230,7 +241,8 @@ class TestProcessor( frames.removeFirst() frames.forEach { frame -> - if (frame.contains(projectContext.classFQN!!)) { + // classFQN will be null for the top level function + if (projectContext.classFQN != null && frame.contains(projectContext.classFQN!!)) { val coveredLineNumber = frame.split(":")[1].replace(")", "").toIntOrNull() if (coveredLineNumber != null) { result.add(coveredLineNumber) @@ -274,7 +286,8 @@ class TestProcessor( children("counter") {} } children("sourcefile") { - isCorrectSourceFile = this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() + isCorrectSourceFile = + this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() children("line") { if (isCorrectSourceFile && this.attributes.getValue("mi") == "0") { setOfLines.add(this.attributes.getValue("nr").toInt()) @@ -295,4 +308,18 @@ class TestProcessor( return TestCase(testCaseId, testCaseName, testCaseCode, setOfLines) } + + /** + * Finds 'javac' compiler (both on Unix & Windows) + * starting from the provided directory. + */ + private fun findJavaCompilerInDirectory(homeDirectory: String): File { + return File(homeDirectory).walk() + .filter { + val isJavaName = + if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") + isJavaName && it.isFile + } + .first() + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt new file mode 100644 index 000000000..3c4ca5637 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestSuiteParserFactory.kt @@ -0,0 +1,31 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestBodyPrinter +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.core.test.java.JavaJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser + +class TestSuiteParserFactory { + companion object { + fun createJUnitTestSuiteParser( + jUnitVersion: JUnitVersion, + language: SupportedLanguage, + testBodyPrinter: TestBodyPrinter, + packageName: String = "", + ): TestSuiteParser = when (language) { + SupportedLanguage.Java -> JavaJUnitTestSuiteParser( + packageName, + jUnitVersion, + testBodyPrinter, + ) + + SupportedLanguage.Kotlin -> KotlinJUnitTestSuiteParser( + packageName, + jUnitVersion, + testBodyPrinter, + ) + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt new file mode 100644 index 000000000..a896d273c --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestsAssemblerFactory.kt @@ -0,0 +1,18 @@ +package org.jetbrains.research.testspark.tools + +import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestSuiteParser +import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler + +class TestsAssemblerFactory { + companion object { + fun create( + indicator: CustomProgressIndicator, + generationData: TestGenerationData, + testSuiteParser: TestSuiteParser, + junitVersion: JUnitVersion, + ) = JUnitTestsAssembler(indicator, generationData, testSuiteParser, junitVersion) + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3ba26b9c5..a7ef25eb2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -11,9 +11,9 @@ import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.IJTestCase -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.services.TestsExecutionResultService import java.io.File @@ -21,68 +21,37 @@ object ToolUtils { val sep = File.separatorChar val pathSep = File.pathSeparatorChar - /** - * Retrieves the imports code from a given test suite code. - * - * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. - * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. - * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. - */ - fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String): MutableSet { - testSuiteCode ?: return mutableSetOf() - return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() - .filter { it.contains("^import".toRegex()) } - .filterNot { it.contains("evosuite".toRegex()) } - .filterNot { it.contains("RunWith".toRegex()) } - .filterNot { it.contains(classFQN.toRegex()) }.toMutableSet() - } - - /** - * Retrieves the package declaration from the given test suite code. - * - * @param testSuiteCode The generated code of the test suite. - * @return The package declaration extracted from the test suite code, or an empty string if no package declaration was found. - */ -// get package from a generated code - fun getPackageFromTestSuiteCode(testSuiteCode: String?): String { - testSuiteCode ?: return "" - if (!testSuiteCode.contains("package")) return "" - val result = testSuiteCode.replace("\r\n", "\n").split("\n") - .filter { it.contains("^package".toRegex()) }.joinToString("").split("package ")[1].split(";")[0] - if (result.isBlank()) return "" - return result - } - /** * Saves the data related to test generation in the specified project's workspace. * * @param project The project in which the test generation data will be saved. * @param report The report object to be added to the test generation result list. - * @param packageLine The package declaration line of the test generation data. + * @param packageName The package declaration line of the test generation data. * @param importsCode The import statements code of the test generation data. */ fun saveData( project: Project, report: Report, - packageLine: String, + packageName: String, importsCode: MutableSet, fileUrl: String, generatedTestData: TestGenerationData, + language: SupportedLanguage = SupportedLanguage.Java, ) { generatedTestData.fileUrl = fileUrl - generatedTestData.packageLine = packageLine + generatedTestData.packageName = packageName generatedTestData.importsCode.addAll(importsCode) project.service().initExecutionResult(report.testCaseList.values.map { it.id }) for (testCase in report.testCaseList.values) { val code = testCase.testCode - testCase.testCode = JavaClassBuilderHelper.generateCode( + testCase.testCode = TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCase.testName), code, generatedTestData.importsCode, - generatedTestData.packageLine, + generatedTestData.packageName, generatedTestData.runWith, generatedTestData.otherInfo, generatedTestData, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 46b982ac1..529bb4b8e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -5,7 +5,7 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.actions.controllers.TestGenerationController -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper @@ -88,7 +88,7 @@ class EvoSuite(override val name: String = "EvoSuite") : Tool { */ override fun generateTestsForLine(project: Project, psiHelper: PsiHelper, caretOffset: Int, fileUrl: String?, testSamplesCode: String, testGenerationController: TestGenerationController) { log.info("Starting tests generation for line by EvoSuite") - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! createPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( getEvoSuiteProcessManager(project), FragmentToTestData( diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt index c1e5e6560..8c180f9df 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt @@ -15,10 +15,13 @@ import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteDefaultsBundle import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.core.utils.CommandLineRunner -import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext @@ -200,8 +203,8 @@ class EvoSuiteProcessManager( ToolUtils.saveData( project, IJReport(testGenerationResult), - ToolUtils.getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode), - ToolUtils.getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), + getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode, SupportedLanguage.Java), + getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), projectContext.fileUrlAsString!!, generatedTestsData, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 01f16176c..980707a2a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -1,11 +1,12 @@ package org.jetbrains.research.testspark.tools.llm import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -23,6 +24,8 @@ import java.nio.file.Path */ class Llm(override val name: String = "LLM") : Tool { + private val log = Logger.getInstance(this::class.java) + /** * Returns an instance of the LLMProcessManager. * @@ -74,6 +77,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for CLASS was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -107,6 +111,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for METHOD was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -141,11 +146,12 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for LINE was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return } - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! val codeType = FragmentToTestData(CodeType.LINE, selectedLine) createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( @@ -174,9 +180,7 @@ class Llm(override val name: String = "LLM") : Tool { fileUrl: String?, testGenerationController: TestGenerationController, ): Pipeline { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! - val packageList = cutPsiClass.qualifiedName.split(".").dropLast(1) - val packageName = packageList.joinToString(".") + val packageName = psiHelper.getPackageName() return Pipeline(project, psiHelper, caretOffset, fileUrl, packageName, testGenerationController) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 437ecd679..271cf4b49 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -57,6 +57,7 @@ class LlmSettingsArguments(private val project: Project) { fun getToken(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIToken llmSettingsState.grazieName -> llmSettingsState.grazieToken + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken else -> "" } @@ -68,6 +69,7 @@ class LlmSettingsArguments(private val project: Project) { fun getModel(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIModel llmSettingsState.grazieName -> llmSettingsState.grazieModel + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel else -> "" } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index e1bcb67ec..1196016b2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -1,36 +1,27 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestSuiteParser import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.java.JavaJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.kotlin.KotlinJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.utils.Language -import org.jetbrains.research.testspark.core.utils.javaImportPattern -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState /** * Assembler class for generating and organizing test cases. * - * @property project The project to which the tests belong. * @property indicator The progress indicator to display the progress of test generation. * @property log The logger for logging debug information. * @property lastTestCount The count of the last generated tests. */ class JUnitTestsAssembler( - val project: Project, val indicator: CustomProgressIndicator, - val generationData: TestGenerationData, + private val generationData: TestGenerationData, + private val testSuiteParser: TestSuiteParser, + val junitVersion: JUnitVersion, ) : TestsAssembler() { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state private val log: Logger = Logger.getInstance(this.javaClass) @@ -58,11 +49,8 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? { - val junitVersion = llmSettingsState.junitVersion - - val parser = createTestSuiteParser(packageName, junitVersion, language) - val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(super.getContent()) + override fun assembleTestSuite(): TestSuiteGeneratedByLLM? { + val testSuite = testSuiteParser.parseTestSuite(super.getContent()) // save RunWith if (testSuite?.runWith?.isNotBlank() == true) { @@ -80,15 +68,4 @@ class JUnitTestsAssembler( testSuite?.testCases?.forEach { testCase -> log.info("Generated test case: $testCase") } return testSuite } - - private fun createTestSuiteParser( - packageName: String, - jUnitVersion: JUnitVersion, - language: Language, - ): TestSuiteParser { - return when (language) { - Language.Java -> JavaJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - Language.Kotlin -> KotlinJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - } - } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index bb1dee0ff..f46dd5603 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -3,25 +3,32 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.FeedbackCycleExecutionResult import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager @@ -34,7 +41,6 @@ import java.nio.file.Path * and is responsible for generating tests using the LLM tool. * * @property project The project in which the test generation is being performed. - * @property prompt The prompt to be sent to the LLM tool. * @property testFileName The name of the generated test file. * @property log An instance of the logger class for logging purposes. * @property llmErrorManager An instance of the LLMErrorManager class. @@ -42,19 +48,23 @@ import java.nio.file.Path */ class LLMProcessManager( private val project: Project, - private val language: Language, + private val language: SupportedLanguage, private val promptManager: PromptManager, private val testSamplesCode: String, - projectSDKPath: Path? = null, + private val projectSDKPath: Path? = null, ) : ProcessManager { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state - private val testFileName: String = "GeneratedTest.java" + private val homeDirectory = + projectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + + private val testFileName: String = when (language) { + SupportedLanguage.Java -> "GeneratedTest.java" + SupportedLanguage.Kotlin -> "GeneratedTest.kt" + } private val log = Logger.getInstance(this::class.java) private val llmErrorManager: LLMErrorManager = LLMErrorManager() private val maxRequests = LlmSettingsArguments(project).maxLLMRequest() - private val testProcessor = TestProcessor(project, projectSDKPath) + private val testProcessor: TestsPersistentStorage = TestProcessor(project, projectSDKPath) /** * Runs the test generator process. @@ -91,16 +101,16 @@ class LLMProcessManager( val report = IJReport() // PROMPT GENERATION - val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) - - val testCompiler = testProcessor.testCompiler + val initialPromptMessage = + promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) // initiate a new RequestManager val requestManager = StandardRequestManagerFactory(project).getRequestManager(project) // adapter for the existing prompt reduction functionality val promptSizeReductionStrategy = object : PromptSizeReductionStrategy { - override fun isReductionPossible(): Boolean = promptManager.isPromptSizeReductionPossible(generatedTestsData) + override fun isReductionPossible(): Boolean = + promptManager.isPromptSizeReductionPossible(generatedTestsData) override fun reduceSizeAndGeneratePrompt(): String { if (!isReductionPossible()) { @@ -115,7 +125,7 @@ class LLMProcessManager( // adapter for the existing test case/test suite string representing functionality val testsPresenter = object : TestsPresenter { - private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) override fun representTestSuite(testSuite: TestSuiteGeneratedByLLM): String { return testSuitePresenter.toStringWithoutExpectedException(testSuite) @@ -126,6 +136,29 @@ class LLMProcessManager( } } + // Creation of JUnit specific parser, printer and assembler + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + packageName, + ) + val testsAssembler = TestsAssemblerFactory.create( + indicator, + generatedTestsData, + testSuiteParser, + jUnitVersion, + ) + + val testCompiler = TestCompilerFactory.create( + project, + jUnitVersion, + language, + homeDirectory, + ) + // Asking LLM to generate a test suite. Here we have a feedback cycle for LLM in case of wrong responses val llmFeedbackCycle = LLMWithFeedbackCycle( language = language, @@ -137,7 +170,7 @@ class LLMProcessManager( resultPath = generatedTestsData.resultPath, buildPath = buildPath, requestManager = requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, generatedTestsData), + testsAssembler = testsAssembler, testCompiler = testCompiler, testStorage = testProcessor, testsPresenter = testsPresenter, @@ -150,8 +183,10 @@ class LLMProcessManager( when (warning) { LLMWithFeedbackCycle.WarningType.TEST_SUITE_PARSING_FAILED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.NO_TEST_CASES_GENERATED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.COMPILATION_ERROR_OCCURRED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("compilationError"), project) } @@ -167,17 +202,21 @@ class LLMProcessManager( // store compilable test cases generatedTestsData.compilableTestCases.addAll(feedbackResponse.compilableTestCases) } + FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("invalidLLMResult"), project, errorMonitor) } + FeedbackCycleExecutionResult.CANCELED -> { log.info("Process stopped") return null } + FeedbackCycleExecutionResult.PROVIDED_PROMPT_TOO_LONG -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("tooLongPromptRequest"), project, errorMonitor) return null } + FeedbackCycleExecutionResult.SAVING_TEST_FILES_ISSUE -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("savingTestFileIssue"), project, errorMonitor) } @@ -190,7 +229,7 @@ class LLMProcessManager( log.info("Save generated test suite and test cases into the project workspace") - val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) val generatedTestSuite: TestSuiteGeneratedByLLM? = feedbackResponse.generatedTestSuite val testSuiteRepresentation = if (generatedTestSuite != null) testSuitePresenter.toString(generatedTestSuite) else null @@ -200,10 +239,11 @@ class LLMProcessManager( ToolUtils.saveData( project, report, - ToolUtils.getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation), - ToolUtils.getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN!!), + getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation, language), + getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN), projectContext.fileUrlAsString!!, generatedTestsData, + language, ) return UIContext(projectContext, generatedTestsData, requestManager, errorMonitor) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index d7ac8f9f5..08e5be765 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -5,7 +5,6 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.util.Computable import com.intellij.openapi.util.TextRange -import com.intellij.psi.PsiDocumentManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle import org.jetbrains.research.testspark.core.data.TestGenerationData @@ -15,7 +14,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptConfiguration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptGenerationContext import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptTemplates -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -31,7 +30,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager * A class that manages prompts for generating unit tests. * * @constructor Creates a PromptManager with the given parameters. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caret The place of the caret. */ class PromptManager( @@ -39,6 +38,9 @@ class PromptManager( private val psiHelper: PsiHelper, private val caret: Int, ) { + /** + * The `classesToTest` is empty when we work with the function outside the class + */ private val classesToTest: List get() { val classesToTest = mutableListOf() @@ -52,7 +54,10 @@ class PromptManager( return classesToTest } - private val cut: PsiClassWrapper = classesToTest[0] + /** + * The `cut` is null when we work with the function outside the class. + */ + private val cut: PsiClassWrapper? = if (classesToTest.isNotEmpty()) classesToTest[0] else null private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state @@ -79,7 +84,7 @@ class PromptManager( .toMap() val context = PromptGenerationContext( - cut = createClassRepresentation(cut), + cut = cut?.let { createClassRepresentation(it) }, classesToTest = classesToTest.map(this::createClassRepresentation).toList(), polymorphismRelations = polymorphismRelations, promptConfiguration = PromptConfiguration( @@ -110,7 +115,12 @@ class PromptManager( .map(this::createClassRepresentation) .toList() - promptGenerator.generatePromptForMethod(method, interestingClassesFromMethod, testSamplesCode) + promptGenerator.generatePromptForMethod( + method, + interestingClassesFromMethod, + testSamplesCode, + psiHelper.getPackageName(), + ) } CodeType.LINE -> { @@ -118,7 +128,7 @@ class PromptManager( val psiMethod = getPsiMethod(cut, getMethodDescriptor(cut, lineNumber))!! // get code of line under test - val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) + val document = psiHelper.getDocumentFromPsiFile() val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) val lineEndOffset = document.getLineEndOffset(lineNumber - 1) @@ -149,7 +159,7 @@ class PromptManager( signature = psiMethod.signature, name = psiMethod.name, text = psiMethod.text!!, - containingClassQualifiedName = psiMethod.containingClass!!.qualifiedName, + containingClassQualifiedName = psiMethod.containingClass?.qualifiedName ?: "", ) } @@ -210,7 +220,6 @@ class PromptManager( * * @param project The project context in which the PsiClasses exist. * @param interestingPsiClasses The set of PsiClassWrappers that are considered interesting. - * @param cutPsiClass The cut PsiClassWrapper to determine polymorphism relations against. * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. */ private fun getPolymorphismRelationsWithQualifiedNames( @@ -219,6 +228,9 @@ class PromptManager( ): MutableMap> { val polymorphismRelations: MutableMap> = mutableMapOf() + // assert(interestingPsiClasses.isEmpty()) + if (cut == null) return polymorphismRelations + interestingPsiClasses.add(cut) interestingPsiClasses.forEach { currentInterestingClass -> @@ -245,9 +257,14 @@ class PromptManager( * @return The matching PsiMethod if found, otherwise an empty string. */ private fun getPsiMethod( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, methodDescriptor: String, ): PsiMethodWrapper? { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return currentPsiMethod + } for (currentPsiMethod in psiClass.allMethods) { val file = psiClass.containingFile val psiHelper = PsiHelperProvider.getPsiHelper(file) @@ -268,9 +285,14 @@ class PromptManager( * @return the method descriptor as a String, or an empty string if no method is found */ private fun getMethodDescriptor( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, lineNumber: Int, ): String { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return psiHelper.generateMethodDescriptor(currentPsiMethod) + } for (currentPsiMethod in psiClass.allMethods) { if (currentPsiMethod.containsLine(lineNumber)) { val file = psiClass.containingFile diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt index 46daefc30..f05d55986 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt @@ -6,6 +6,7 @@ import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager interface RequestManagerFactory { @@ -20,6 +21,7 @@ class StandardRequestManagerFactory(private val project: Project) : RequestManag return when (val platform = LlmSettingsArguments(project).currentLLMPlatformName()) { llmSettingsState.openAIName -> OpenAIRequestManager(project) llmSettingsState.grazieName -> GrazieRequestManager(project) + llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project) else -> throw IllegalStateException("Unknown selected platform: $platform") } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt index c2267beb8..45581b8cf 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt @@ -62,14 +62,12 @@ class GrazieRequestManager(project: Project) : IJRequestManager(project) { } private fun getMessages(): List> { - val result = mutableListOf>() - chatHistory.forEach { + return chatHistory.map { val role = when (it.role) { ChatMessage.ChatRole.User -> "user" ChatMessage.ChatRole.Assistant -> "assistant" } - result.add(Pair(role, it.content)) + (role to it.content) } - return result } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt new file mode 100644 index 000000000..e5b93f588 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFacePlatform.kt @@ -0,0 +1,9 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform + +class HuggingFacePlatform( + override val name: String = "HuggingFace", + override var token: String = "", + override var model: String = "", +) : LLMPlatform diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt new file mode 100644 index 000000000..6ef09950f --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt @@ -0,0 +1,33 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import org.jetbrains.research.testspark.core.data.ChatMessage + +data class Parameters( + val topProbability: Double, + val temperature: Double, +) + +data class HuggingFaceRequestBody( + val messages: List, + val parameters: Parameters, +) + +/** + * Sets LLM settings required to send inference requests to HF + * For more info, see https://huggingface.co/docs/api-inference/en/detailed_parameters + */ +fun HuggingFaceRequestBody.toMap(): Map { + return mapOf( + "inputs" to this.messages.joinToString(separator = "\n") { it.content }, + // TODO: These parameters can be set by the user in the plugin's settings too. + "parameters" to mapOf( + "top_p" to this.parameters.topProbability, + "temperature" to this.parameters.temperature, + "min_length" to 4096, + "max_length" to 8192, + "max_new_tokens" to 250, + "max_time" to 120.0, + "return_full_text" to false, + ), + ) +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt new file mode 100644 index 000000000..e99a25bf2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt @@ -0,0 +1,116 @@ +package org.jetbrains.research.testspark.tools.llm.generation.hf + +import com.google.gson.GsonBuilder +import com.google.gson.JsonParser +import com.intellij.openapi.project.Project +import com.intellij.util.io.HttpRequests +import com.intellij.util.io.HttpRequests.HttpStatusException +import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle +import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.data.ChatUserMessage +import org.jetbrains.research.testspark.core.monitor.ErrorMonitor +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestsAssembler +import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager +import org.jetbrains.research.testspark.tools.llm.generation.IJRequestManager +import java.net.HttpURLConnection + +/** + * A class to manage requests sent to large language models hosted on HuggingFace + */ +class HuggingFaceRequestManager(project: Project) : IJRequestManager(project) { + private val url = "https://api-inference.huggingface.co/models/meta-llama/" + + // TODO: The user should be able to change these numbers in the plugin's settings + private val topProbability = 0.9 + private val temperature = 0.9 + + private val llmErrorManager = LLMErrorManager() + + override fun send( + prompt: String, + indicator: CustomProgressIndicator, + testsAssembler: TestsAssembler, + errorMonitor: ErrorMonitor, + ): SendResult { + val httpRequest = HttpRequests.post( + url + LlmSettingsArguments(project).getModel(), + "application/json", + ).tuner { + it.setRequestProperty("Authorization", "Bearer $token") + } + + // Add system prompt + if (chatHistory.size == 1) { + chatHistory[0] = ChatUserMessage( + createInstructionPrompt( + chatHistory[0].content, + ), + ) + } + + val llmRequestBody = HuggingFaceRequestBody(chatHistory, Parameters(topProbability, temperature)).toMap() + var sendResult = SendResult.OK + try { + httpRequest.connect { + it.write(GsonBuilder().disableHtmlEscaping().create().toJson(llmRequestBody)) + when (val responseCode = (it.connection as HttpURLConnection).responseCode) { + HttpURLConnection.HTTP_OK -> { + val text = it.reader.readLine() + val generatedTestCases = extractLLMGeneratedCode( + JsonParser.parseString(text).asJsonArray[0] + .asJsonObject["generated_text"].asString.trim(), + ) + testsAssembler.consume(generatedTestCases) + } + + HttpURLConnection.HTTP_INTERNAL_ERROR -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("serverProblems"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + + HttpURLConnection.HTTP_BAD_REQUEST -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("hfServerError"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + } + } + } catch (e: HttpStatusException) { + log.error { "Error in sending request: ${e.message}" } + } + return sendResult + } + + /** + * Creates the required prompt for Llama models. For more details see: + * https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + */ + private fun createInstructionPrompt(userMessage: String): String { + // TODO: This is Llama-specific and should support other LLMs hosted on HF too. + return "[INST] <> ${LLMDefaultsBundle.get("huggingFaceInitialSystemPrompt")} <> $userMessage [/INST]" + } + + /** + * Extracts code blocks in LLMs' response. + * Also, it handles the cases where the LLM-generated code does not end with ``` + */ + private fun extractLLMGeneratedCode(text: String): String { + // TODO: This method should support other languages other than Java. + val modifiedText = text.replace("```java", "```").replace("````", "```") + val tripleTickBlockIndex = modifiedText.indexOf("```") + val codePart = modifiedText.substring(tripleTickBlockIndex + 3) + val lines = codePart.lines() + val filteredLines = lines.filter { line -> line != "```" } + val code = filteredLines.joinToString("\n") + return "```\n$code\n```" + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt index 40e0c3fba..33138c4f8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt @@ -1,9 +1,30 @@ package org.jetbrains.research.testspark.tools.llm.generation.openai -import org.jetbrains.research.testspark.core.data.ChatMessage +/** + * Adheres the naming of fields for OpenAI chat completion API and checks the correctness of a `role`. + *
+ * Use this class as a carrier of messages that should be sent to OpenAI API. + */ +data class OpenAIChatMessage(val role: String, val content: String) { + private companion object { + /** + * The API strictly defines the set of roles. + * The `function` role is omitted because it is already deprecated. + * + * See: https://platform.openai.com/docs/api-reference/chat/create + */ + val supportedRoles = listOf("user", "assistant", "system", "tool") + } + + init { + if (!supportedRoles.contains(role)) { + throw IllegalArgumentException("'$role' is not supported ${OpenAIChatMessage::class}. Available roles are: ${(supportedRoles.joinToString(", ") { "'$it'" })}") + } + } +} data class OpenAIRequestBody( val model: String, - val messages: List, + val messages: List, val stream: Boolean = true, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt index ed6607d3e..1d9d6a9a4 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt @@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.io.HttpRequests import com.intellij.util.io.HttpRequests.HttpStatusException import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.data.ChatMessage import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -35,22 +36,29 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { errorMonitor: ErrorMonitor, ): SendResult { // Prepare the chat - val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), chatHistory) + val messages = chatHistory.map { + val role = when (it.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "assistant" + } + OpenAIChatMessage(role, it.content) + } + + val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), messages) var sendResult = SendResult.OK try { - httpRequest.connect { - it.write(GsonBuilder().create().toJson(llmRequestBody)) + httpRequest.connect { request -> + // send request to OpenAI API + request.write(GsonBuilder().create().toJson(llmRequestBody)) + + val connection = request.connection as HttpURLConnection // check response - when (val responseCode = (it.connection as HttpURLConnection).responseCode) { + when (val responseCode = connection.responseCode) { HttpURLConnection.HTTP_OK -> { - assembleLlmResponse( - httpRequest = it, - indicator, - testsAssembler, - ) + assembleLlmResponse(request, testsAssembler, indicator) } HttpURLConnection.HTTP_INTERNAL_ERROR -> { @@ -105,13 +113,12 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { */ private fun assembleLlmResponse( httpRequest: HttpRequests.Request, - indicator: CustomProgressIndicator, testsAssembler: TestsAssembler, + indicator: CustomProgressIndicator, ) { while (true) { if (ToolUtils.isProcessCanceled(indicator)) return - Thread.sleep(50L) var text = httpRequest.reader.readLine() if (text.isEmpty()) continue diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index b1473b0c9..10aded741 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -3,12 +3,14 @@ package org.jetbrains.research.testspark.tools.llm.test import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper +import org.jetbrains.research.testspark.tools.TestClassCodeGeneratorFactory class JUnitTestSuitePresenter( private val project: Project, private val generatedTestsData: TestGenerationData, + private val language: SupportedLanguage, ) { /** * Returns a string representation of this object. @@ -34,12 +36,12 @@ class JUnitTestSuitePresenter( // Add each test testCases.forEach { testCase -> testBody += "$testCase\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -57,12 +59,12 @@ class JUnitTestSuitePresenter( testCaseIndex: Int, ): String = testSuite.run { - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCases[testCaseIndex].name), testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -81,12 +83,12 @@ class JUnitTestSuitePresenter( // Add each test (exclude expected exception) testCases.forEach { testCase -> testBody += "${testCase.toStringWithoutExpectedException()}\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -105,8 +107,8 @@ class JUnitTestSuitePresenter( fun getPrintablePackageString(testSuite: TestSuiteGeneratedByLLM): String { return testSuite.run { when { - packageString.isEmpty() || packageString.isBlank() -> "" - else -> packageString + packageName.isEmpty() || packageName.isBlank() -> "" + else -> packageName } } } diff --git a/src/main/resources/properties/llm/LLMDefaults.properties b/src/main/resources/properties/llm/LLMDefaults.properties index 156f15cbd..1eddae6e2 100644 --- a/src/main/resources/properties/llm/LLMDefaults.properties +++ b/src/main/resources/properties/llm/LLMDefaults.properties @@ -4,6 +4,10 @@ openAIModel= grazieName=AI Assistant JetBrains grazieToken= grazieModel= +huggingFaceName=HuggingFace +huggingFaceToken= +huggingFaceModel= +huggingFaceInitialSystemPrompt=You are a helpful and honest code and programming assistant. Please, respond concisely and truthfully. maxLLMRequest=3 maxInputParamsDepth=2 maxPolyDepth=2 diff --git a/src/main/resources/properties/llm/LLMMessages.properties b/src/main/resources/properties/llm/LLMMessages.properties index db087d5c1..3502840ab 100644 --- a/src/main/resources/properties/llm/LLMMessages.properties +++ b/src/main/resources/properties/llm/LLMMessages.properties @@ -14,4 +14,5 @@ grazieError=Grazie test generation feature is not available in this build. removeTemplateMessage=Choose another default template to remove this one. removeTemplateTitle=Can't Be Removed defaultPromptIsNotValidMessage=Default prompt is not valid. Fix it, please. -defaultPromptIsNotValidTitle=Incorrect Prompt State \ No newline at end of file +defaultPromptIsNotValidTitle=Incorrect Prompt State +hfServerError=The selected model may need an HF PRO subscription to use! \ No newline at end of file From fbaf10296a77ca4d6e51b1ce63e4afb9113d3f88 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:05:29 +0200 Subject: [PATCH 10/19] fix compalation bug --- .../jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 523903f96..fd8a78a1b 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -10,7 +10,10 @@ import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiFile import com.intellij.psi.util.parentOfType -import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtFunction +import org.jetbrains.kotlin.psi.KtPsiUtil import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName From 70bd7b2dce1b499bacd23c5a0884b1ff5f036387 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:08:34 +0200 Subject: [PATCH 11/19] fixing compilation bug --- .../services/kotlin/KotlinTestCaseDisplayService.kt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt index a80952747..a77edd16d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt @@ -39,7 +39,8 @@ import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.display.TestCasePanelFactory import org.jetbrains.research.testspark.display.TopButtonsPanelFactory import org.jetbrains.research.testspark.helpers.ReportHelper -import org.jetbrains.research.testspark.helpers.kotlin.KotlinClassBuilderHelper +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeAnalyzer +import org.jetbrains.research.testspark.helpers.kotlin.KotlinTestClassCodeGenerator import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.services.CoverageVisualisationService @@ -417,8 +418,8 @@ class KotlinTestCaseDisplayService(private val project: Project) : TestCaseDispl // insert tests to a code testCaseComponents.reversed().forEach { val testMethodCode = - KotlinClassBuilderHelper.extractFirstTestMethodCode( - KotlinClassBuilderHelper.formatCode( + KotlinTestClassCodeAnalyzer.extractFirstTestMethodCode( + KotlinTestClassCodeGenerator.formatCode( project, it.replace("\r\n", "\n") .replace("verifyException(", "// verifyException("), From e0f01ddd2d2de28033cb42766b84cf0b7ba99ed8 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:17:33 +0200 Subject: [PATCH 12/19] TopButtonsPanelStrategy refactoring --- build.gradle.kts | 1 + .../testspark/core/data/TestGenerationData.kt | 4 +- .../generation/llm/LLMWithFeedbackCycle.kt | 14 +- .../testspark/core/generation/llm/Utils.kt | 48 +- .../generation/llm/network/RequestManager.kt | 12 +- .../generation/llm/prompt/PromptBuilder.kt | 13 +- .../generation/llm/prompt/PromptGenerator.kt | 6 +- .../llm/prompt/configuration/Configuration.kt | 5 +- .../testspark/core/test/TestCompiler.kt | 60 +- .../testspark/core/test/TestsAssembler.kt | 8 +- .../core/test/TestsPersistentStorage.kt | 1 + .../core/test/data/TestCaseGeneratedByLLM.kt | 29 +- .../core/test/data/TestSuiteGeneratedByLLM.kt | 4 +- .../JavaTestCompilationDependencies.kt | 30 - .../core/test/parsers/TestSuiteParser.kt | 20 - .../parsers/java/JavaJUnitTestSuiteParser.kt | 22 - .../kotlin/KotlinJUnitTestSuiteParser.kt | 22 - .../JUnitTestSuiteParserStrategy.kt | 173 ------ .../research/testspark/core/utils/Language.kt | 8 - .../research/testspark/core/utils/Patterns.kt | 10 +- .../kotlin/KotlinJUnitTestSuiteParserTest.kt | 161 +++++- .../testspark/java/JavaPsiClassWrapper.kt | 32 +- .../research/testspark/java/JavaPsiHelper.kt | 57 +- .../testspark/kotlin/KotlinPsiClassWrapper.kt | 40 +- .../testspark/kotlin/KotlinPsiHelper.kt | 97 ++-- .../kotlin/KotlinPsiMethodWrapper.kt | 20 + langwrappers/build.gradle.kts | 2 - .../testspark/langwrappers/PsiComponents.kt | 42 +- .../testspark/actions/TestSparkAction.kt | 85 +-- .../actions/llm/LLMSampleSelectorFactory.kt | 5 +- .../actions/llm/LLMSetupPanelFactory.kt | 6 +- .../actions/llm/TestSamplePanelFactory.kt | 4 +- .../testspark/appstarter/TestSparkStarter.kt | 14 +- .../research/testspark/data/CodeType.kt | 8 - .../testspark/data/FragmentToTestData.kt | 2 + .../testspark/display/TestCasePanelFactory.kt | 53 +- .../display/TopButtonsPanelFactory.kt | 196 +------ .../testspark/helpers/CoverageHelper.kt | 6 +- .../helpers/JavaClassBuilderHelper.kt | 204 ------- .../research/testspark/helpers/LLMHelper.kt | 55 +- .../CoverageToolWindowDisplayService.kt | 0 .../services/TestCaseDisplayService.kt | 527 +----------------- .../settings/llm/LLMSettingsComponent.kt | 2 +- .../settings/llm/LLMSettingsConfigurable.kt | 12 + .../settings/llm/LLMSettingsState.kt | 6 + .../testspark/tools/LibraryPathsProvider.kt | 4 +- .../research/testspark/tools/Pipeline.kt | 45 +- .../testspark/tools/TestCompilerFactory.kt | 17 +- .../research/testspark/tools/TestProcessor.kt | 61 +- .../research/testspark/tools/ToolUtils.kt | 45 +- .../testspark/tools/evosuite/EvoSuite.kt | 4 +- .../generation/EvoSuiteProcessManager.kt | 9 +- .../research/testspark/tools/llm/Llm.kt | 14 +- .../tools/llm/LlmSettingsArguments.kt | 2 + .../llm/generation/JUnitTestsAssembler.kt | 35 +- .../tools/llm/generation/LLMProcessManager.kt | 76 ++- .../tools/llm/generation/PromptManager.kt | 44 +- .../llm/generation/RequestManagerFactory.kt | 2 + .../generation/grazie/GrazieRequestManager.kt | 6 +- .../generation/openai/OpenAIRequestBody.kt | 25 +- .../generation/openai/OpenAIRequestManager.kt | 29 +- .../tools/llm/test/JUnitTestSuitePresenter.kt | 20 +- .../properties/llm/LLMDefaults.properties | 4 + .../properties/llm/LLMMessages.properties | 3 +- 64 files changed, 882 insertions(+), 1689 deletions(-) delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt delete mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt delete mode 100644 src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt diff --git a/build.gradle.kts b/build.gradle.kts index 13da233c4..5e6621e29 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -157,6 +157,7 @@ dependencies { // https://mvnrepository.com/artifact/org.mockito/mockito-all testImplementation("org.mockito:mockito-all:1.10.19") + testImplementation("org.mockito.kotlin:mockito-kotlin:5.1.0") // https://mvnrepository.com/artifact/net.jqwik/jqwik testImplementation("net.jqwik:jqwik:1.6.5") diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt index d11f346d5..a35212cb1 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt @@ -16,7 +16,7 @@ data class TestGenerationData( // Code required of imports and package for generated tests var importsCode: MutableSet = mutableSetOf(), - var packageLine: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", @@ -37,7 +37,7 @@ data class TestGenerationData( resultName = "" fileUrl = "" importsCode = mutableSetOf() - packageLine = "" + packageName = "" runWith = "" otherInfo = "" polyDepthReducing = 0 diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 0c8a428aa..973b26e7a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -10,13 +10,13 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeRed import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import java.io.File enum class FeedbackCycleExecutionResult { @@ -45,7 +45,7 @@ data class FeedbackResponse( class LLMWithFeedbackCycle( private val report: Report, - private val language: Language, + private val language: SupportedLanguage, private val initialPromptMessage: String, private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in result path @@ -167,13 +167,15 @@ class LLMWithFeedbackCycle( generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) } else { for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = - "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + val testCaseFilename = when (language) { + SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" + SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt" + } val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) val saveFilepath = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testCaseRepresentation, resultPath, testCaseFilename, @@ -184,7 +186,7 @@ class LLMWithFeedbackCycle( } val generatedTestSuitePath: String = testStorage.saveGeneratedTest( - generatedTestSuite.packageString, + generatedTestSuite.packageName, testsPresenter.representTestSuite(generatedTestSuite), resultPath, testSuiteFilename, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index 76cb74c17..1942a6a86 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -4,13 +4,47 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.utils.javaPackagePattern +import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import java.util.Locale // TODO: find a better place for the below functions +/** + * Retrieves the package declaration from the given test suite code for any language. + * + * @param testSuiteCode The generated code of the test suite. + * @return The package name extracted from the test suite code, or an empty string if no package declaration was found. + */ +fun getPackageFromTestSuiteCode(testSuiteCode: String?, language: SupportedLanguage): String { + testSuiteCode ?: return "" + return when (language) { + SupportedLanguage.Kotlin -> kotlinPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + SupportedLanguage.Java -> javaPackagePattern.find(testSuiteCode)?.groups?.get(1)?.value.orEmpty() + } +} + +/** + * Retrieves the imports code from a given test suite code. + * + * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. + * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. + * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. + */ +fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String?): MutableSet { + testSuiteCode ?: return mutableSetOf() + return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() + .filter { it.contains("^import".toRegex()) } + .filterNot { it.contains("evosuite".toRegex()) } + .filterNot { it.contains("RunWith".toRegex()) } + // classFQN will be null for the top level function + .filterNot { classFQN != null && it.contains(classFQN.toRegex()) } + .toMutableSet() +} + /** * Returns the generated class name for a given test case. * @@ -39,7 +73,7 @@ fun getClassWithTestCaseName(testCaseName: String): String { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun executeTestCaseModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -50,15 +84,7 @@ fun executeTestCaseModificationRequest( // Update Token information val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task" - var packageName = "" - testCase.split("\n")[0].let { - if (it.startsWith("package")) { - packageName = it - .removePrefix("package ") - .removeSuffix(";") - .trim() - } - } + val packageName = getPackageFromTestSuiteCode(testCase, language) val response = requestManager.request( language, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index 689eec798..441e51231 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -7,8 +7,8 @@ import org.jetbrains.research.testspark.core.data.ChatUserMessage import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.core.utils.Language abstract class RequestManager(var token: String) { enum class SendResult { @@ -31,7 +31,7 @@ abstract class RequestManager(var token: String) { * @return the generated TestSuite, or null and prompt message */ open fun request( - language: Language, + language: SupportedLanguage, prompt: String, indicator: CustomProgressIndicator, packageName: String, @@ -65,7 +65,7 @@ abstract class RequestManager(var token: String) { open fun processResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { // save the full response in the chat history val response = testsAssembler.getContent() @@ -78,7 +78,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) @@ -97,7 +97,7 @@ abstract class RequestManager(var token: String) { open fun processUserFeedbackResponse( testsAssembler: TestsAssembler, packageName: String, - language: Language, + language: SupportedLanguage, ): LLMResponse { val response = testsAssembler.getContent() @@ -108,7 +108,7 @@ abstract class RequestManager(var token: String) { return LLMResponse(ResponseErrorCode.EMPTY_LLM_RESPONSE, null) } - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite(packageName, language) + val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() return if (testSuiteGeneratedByLLM == null) { LLMResponse(ResponseErrorCode.TEST_SUITE_PARSING_FAILURE, null) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 278d58655..036e87a0d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -78,7 +78,7 @@ internal class PromptBuilder(private var prompt: String) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" } for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java")) { + if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { continue } @@ -88,7 +88,9 @@ internal class PromptBuilder(private var prompt: String) { // Skip java methods // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java")) { + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin") + ) { continue } @@ -106,8 +108,11 @@ internal class PromptBuilder(private var prompt: String) { ) = apply { val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { - var fullText = "" - + // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable + var fullText = when { + polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" + else -> "" + } polymorphismRelations.forEach { entry -> for (currentSubClass in entry.value) { val subClassTypeName = when (currentSubClass.classType) { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt index 3afbd3cff..72340867a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt @@ -19,7 +19,7 @@ class PromptGenerator( fun generatePromptForClass(interestingClasses: List, testSamplesCode: String): String { val prompt = PromptBuilder(promptTemplates.classPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(context.cut.qualifiedName) + .insertName(context.cut!!.qualifiedName) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(context.cut.fullText, context.classesToTest) @@ -44,10 +44,12 @@ class PromptGenerator( method: MethodRepresentation, interestingClassesFromMethod: List, testSamplesCode: String, + packageName: String, ): String { + val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" val prompt = PromptBuilder(promptTemplates.methodPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName("${context.cut.qualifiedName}.${method.name}") + .insertName(name) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(method.text, context.classesToTest) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt index 4094de1aa..6b87e8941 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt @@ -10,7 +10,10 @@ import org.jetbrains.research.testspark.core.data.ClassType * @property polymorphismRelations A map where the key represents a ClassRepresentation object and the value is a list of its detected subclasses. */ data class PromptGenerationContext( - val cut: ClassRepresentation, + /** + * The cut is null when we want to generate tests for top-level function + */ + val cut: ClassRepresentation?, val classesToTest: List, val polymorphismRelations: Map>, val promptConfiguration: PromptConfiguration, diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt index bc4d40617..b49281aaf 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt @@ -1,32 +1,24 @@ package org.jetbrains.research.testspark.core.test -import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil -import java.io.File data class TestCasesCompilationResult( val allTestCasesCompilable: Boolean, val compilableTestCases: MutableSet, ) -/** - * TestCompiler is a class that is responsible for compiling generated test cases using the proper javac. - * It provides methods for compiling test cases and code files. - */ -open class TestCompiler( - private val javaHomeDirectoryPath: String, +abstract class TestCompiler( private val libPaths: List, private val junitLibPaths: List, ) { - private val log = KotlinLogging.logger { this::class.java } - /** - * Compiles the generated files with test cases using the proper javac. + * Compiles a list of test cases and returns the compilation result. * - * @return true if all the provided test cases are successfully compiled, - * otherwise returns false. + * @param generatedTestCasesPaths A list of file paths where the generated test cases are located. + * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. + * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled. + * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases. */ fun compileTestCases( generatedTestCasesPaths: List, @@ -51,45 +43,11 @@ open class TestCompiler( * Compiles the code at the specified path using the provided project build path. * * @param path The path of the code file to compile. - * @param projectBuildPath The project build path to use during compilation. + * @param projectBuildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. * @return A pair containing a boolean value indicating whether the compilation was successful (true) or not (false), * and a string message describing any error encountered during compilation. */ - fun compileCode(path: String, projectBuildPath: String): Pair { - // find the proper javac - val javaCompile = File(javaHomeDirectoryPath).walk() - .filter { - val isCompilerName = if (DataFilesUtil.isWindows()) it.name.equals("javac.exe") else it.name.equals("javac") - isCompilerName && it.isFile - } - .firstOrNull() - - if (javaCompile == null) { - val msg = "Cannot find java compiler 'javac' at '$javaHomeDirectoryPath'" - log.error { msg } - throw RuntimeException(msg) - } - - println("javac found at '${javaCompile.absolutePath}'") - - // compile file - val errorMsg = CommandLineRunner.run( - arrayListOf( - javaCompile.absolutePath, - "-cp", - "\"${getPath(projectBuildPath)}\"", - path, - ), - ) - - log.info { "Error message: '$errorMsg'" } - - // create .class file path - val classFilePath = path.replace(".java", ".class") - - // check is .class file exists - return Pair(File(classFilePath).exists(), errorMsg) - } + abstract fun compileCode(path: String, projectBuildPath: String): Pair /** * Generates the path for the command by concatenating the necessary paths. @@ -97,7 +55,7 @@ open class TestCompiler( * @param buildPath The path of the build file. * @return The generated path as a string. */ - fun getPath(buildPath: String): String { + fun getClassPaths(buildPath: String): String { // create the path for the command val separator = DataFilesUtil.classpathSeparator val dependencyLibPath = libPaths.joinToString(separator.toString()) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index 6e5a4e127..0d9c672de 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -1,7 +1,6 @@ package org.jetbrains.research.testspark.core.test import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language abstract class TestsAssembler { private var rawText = "" @@ -33,10 +32,9 @@ abstract class TestsAssembler { } /** - * Extracts test cases from raw text and generates a TestSuite using the given package name. + * Extracts test cases from raw text and generates a TestSuite. * - * @param packageName The package name to be set in the generated TestSuite. - * @return A TestSuiteGeneratedByLLM object containing the extracted test cases and package name. + * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases. */ - abstract fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM? } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt index 1673fea4a..b9d50132c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsPersistentStorage.kt @@ -4,6 +4,7 @@ package org.jetbrains.research.testspark.core.test * The TestPersistentStorage interface represents a contract for saving generated tests to a specified file system location. */ interface TestsPersistentStorage { + /** * Save the generated tests to a specified directory. * diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt index 6ef9f6907..2a565e82e 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.core.test.data +import org.jetbrains.research.testspark.core.test.TestBodyPrinter + /** * * Represents a test case generated by LLM. @@ -11,6 +13,7 @@ data class TestCaseGeneratedByLLM( var expectedException: String = "", var throwsException: String = "", var lines: MutableList = mutableListOf(), + val printTestBodyStrategy: TestBodyPrinter, ) { /** @@ -104,31 +107,7 @@ data class TestCaseGeneratedByLLM( * @return a string containing the body of test case */ private fun printTestBody(testInitiatedText: String): String { - var testFullText = testInitiatedText - - // start writing the test signature - testFullText += "\n\tpublic void $name() " - - // add throws exception if exists - if (throwsException.isNotBlank()) { - testFullText += "throws $throwsException" - } - - // start writing the test lines - testFullText += "{\n" - - // write each line - lines.forEach { line -> - testFullText += when (line.type) { - TestLineType.BREAK -> "\t\t\n" - else -> "\t\t${line.text}\n" - } - } - - // close test case - testFullText += "\t}\n" - - return testFullText + return printTestBodyStrategy.printTestBody(testInitiatedText, lines, throwsException, name) } /** diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt index 211063bb7..4fac9b8b9 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt @@ -4,12 +4,12 @@ package org.jetbrains.research.testspark.core.test.data * Represents a test suite generated by LLM. * * @property imports The set of import statements in the test suite. - * @property packageString The package string of the test suite. + * @property packageName The package name of the test suite. * @property testCases The list of test cases in the test suite. */ data class TestSuiteGeneratedByLLM( var imports: Set = emptySet(), - var packageString: String = "", + var packageName: String = "", var runWith: String = "", var otherInfo: String = "", var testCases: MutableList = mutableListOf(), diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt deleted file mode 100644 index 2e78b0b50..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/dependencies/JavaTestCompilationDependencies.kt +++ /dev/null @@ -1,30 +0,0 @@ -package org.jetbrains.research.testspark.core.test.data.dependencies - -import org.jetbrains.research.testspark.core.data.JarLibraryDescriptor - -/** - * The class represents a list of dependencies required for java test compilation. - * The libraries listed are used during test suite/test case compilation. - */ -class JavaTestCompilationDependencies { - companion object { - fun getJarDescriptors() = listOf( - JarLibraryDescriptor( - "mockito-core-5.0.0.jar", - "https://repo1.maven.org/maven2/org/mockito/mockito-core/5.0.0/mockito-core-5.0.0.jar", - ), - JarLibraryDescriptor( - "hamcrest-core-1.3.jar", - "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", - ), - JarLibraryDescriptor( - "byte-buddy-1.14.6.jar", - "https://repo1.maven.org/maven2/net/bytebuddy/byte-buddy/1.14.6/byte-buddy-1.14.6.jar", - ), - JarLibraryDescriptor( - "byte-buddy-agent-1.14.6.jar", - "https://repo1.maven.org/maven2/net/bytebuddy/byte-buddy-agent/1.14.6/byte-buddy-agent-1.14.6.jar", - ), - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt deleted file mode 100644 index a0551ed7c..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/TestSuiteParser.kt +++ /dev/null @@ -1,20 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers - -import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM - -data class TestCaseParseResult( - val testCase: TestCaseGeneratedByLLM?, - val errorMessage: String, - val errorOccurred: Boolean, -) - -interface TestSuiteParser { - /** - * Extracts test cases from raw text and generates a test suite using the given package name. - * - * @param rawText The raw text provided by the LLM that contains the generated test cases. - * @return A GeneratedTestSuite instance containing the extracted test cases. - */ - fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt deleted file mode 100644 index a8728bbf2..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/java/JavaJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.java - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class JavaJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "void", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt deleted file mode 100644 index 09bdbc627..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParser.kt +++ /dev/null @@ -1,22 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.kotlin - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.strategies.JUnitTestSuiteParserStrategy - -class KotlinJUnitTestSuiteParser( - private val packageName: String, - private val junitVersion: JUnitVersion, - private val importPattern: Regex, -) : TestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { - return JUnitTestSuiteParserStrategy.parseTestSuite( - rawText, - junitVersion, - importPattern, - packageName, - testNamePattern = "fun", - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt deleted file mode 100644 index 98c6827c5..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/parsers/strategies/JUnitTestSuiteParserStrategy.kt +++ /dev/null @@ -1,173 +0,0 @@ -package org.jetbrains.research.testspark.core.test.parsers.strategies - -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM -import org.jetbrains.research.testspark.core.test.data.TestLine -import org.jetbrains.research.testspark.core.test.data.TestLineType -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestCaseParseResult - -class JUnitTestSuiteParserStrategy { - companion object { - fun parseTestSuite( - rawText: String, - junitVersion: JUnitVersion, - importPattern: Regex, - packageName: String, - testNamePattern: String, - ): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - var rawCode = rawText - - if (rawText.contains("```")) { - rawCode = rawText.split("```")[1] - } - - // save imports - val imports = importPattern.findAll(rawCode, 0) - .map { it.groupValues[0] } - .toSet() - - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" - - val testSet: MutableList = rawCode.split("@Test").toMutableList() - - // save annotations and pre-set methods - val otherInfo: String = run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } - - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() - - testSet.forEach ca@{ - val rawTest = "@Test$it" - - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: TestCaseParseResult = - testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern) // /// - - if (result.errorOccurred) { - println("WARNING: ${result.errorMessage}") - return@ca - } - - val currentTest = result.testCase!! - - // TODO: make logging work - // log.info("New test case: $currentTest") - println("New test case: $currentTest") - - testCases.add(currentTest) - } - - val testSuite = TestSuiteGeneratedByLLM( - imports = imports, - packageString = packageName, - runWith = runWith, - otherInfo = otherInfo, - testCases = testCases, - ) - - return testSuite - } catch (e: Exception) { - return null - } - } - } -} - -private class JUnitTestCaseParser { - fun parse(rawTest: String, isLastTestCaseInTestSuite: Boolean, testNamePattern: String): TestCaseParseResult { - var expectedException = "" - var throwsException = "" - val testLines: MutableList = mutableListOf() - - // Get expected Exception - if (rawTest.startsWith("@Test(expected =")) { - expectedException = rawTest.split(")")[0].trim() - } - - // Get unexpected exceptions - /* Each test case should follow fun {...} - Tests do not return anything so it is safe to consider that void always appears before test case name - */ - val voidString = testNamePattern - if (!rawTest.contains(voidString)) { - return TestCaseParseResult( - testCase = null, - errorMessage = "The raw Test does not contain $voidString:\n $rawTest", - errorOccurred = true, - ) - } - val interestingPartOfSignature = rawTest.split(voidString)[1] - .split("{")[0] - .split("()")[1] - .trim() - - if (interestingPartOfSignature.contains("throws")) { - throwsException = interestingPartOfSignature.split("throws")[1].trim() - } - - // Get test name - val testName: String = rawTest.split(voidString)[1] - .split("()")[0] - .trim() - - // Get test body and remove opening bracket - var testBody = rawTest.split("{").toMutableList().apply { removeFirst() } - .joinToString("{").trim() - - // remove closing bracket - val tempList = testBody.split("}").toMutableList() - tempList.removeLast() - - if (isLastTestCaseInTestSuite) { - // it is the last test, thus we should remove another closing bracket - if (tempList.isNotEmpty()) { - tempList.removeLast() - } else { - println("WARNING: the final test does not have the enclosing bracket:\n $testBody") - } - } - - testBody = tempList.joinToString("}") - - // Save each line - val rawLines = testBody.split("\n").toMutableList() - rawLines.forEach { rawLine -> - val line = rawLine.trim() - - val type: TestLineType = when { - line.startsWith("//") -> TestLineType.COMMENT - line.isBlank() -> TestLineType.BREAK - line.lowercase().startsWith("assert") -> TestLineType.ASSERTION - else -> TestLineType.CODE - } - - testLines.add(TestLine(type, line)) - } - - val currentTest = TestCaseGeneratedByLLM( - name = testName, - expectedException = expectedException, - throwsException = throwsException, - lines = testLines, - ) - - return TestCaseParseResult( - testCase = currentTest, - errorMessage = "", - errorOccurred = false, - ) - } -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt deleted file mode 100644 index 250ec7cba..000000000 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Language.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.core.utils - -/** - * Language ID string should be the same as the language name in com.intellij.lang.Language - */ -enum class Language(val languageId: String) { - Java("JAVA"), Kotlin("Kotlin") -} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt index 95903bf8c..fb1da6841 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/utils/Patterns.kt @@ -6,9 +6,17 @@ val javaImportPattern = options = setOf(RegexOption.MULTILINE), ) +/** + * Parse all the possible Kotlin import patterns + * + * import org.mockito.Mockito.`when` + * import kotlin.math.cos + * import kotlin.math.* + * import kotlin.math.PI as piValue + */ val kotlinImportPattern = Regex( - pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?", + pattern = "^import\\s+((?:[a-zA-Z_]\\w*\\.)*(?:\\w*\\.?)*)?(\\*)?( as \\w*)?(`\\w*`)?", options = setOf(RegexOption.MULTILINE), ) diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt index 2ebcde0c9..63fbd0abc 100644 --- a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -2,14 +2,17 @@ package org.jetbrains.research.testspark.core.test.parsers.kotlin import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.kotlinImportPattern +import org.jetbrains.research.testspark.core.test.kotlin.KotlinJUnitTestSuiteParser +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestBodyPrinter +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test -import kotlin.test.assertNotNull class KotlinJUnitTestSuiteParserTest { @Test - fun testFunction() { + fun testParseTestSuite() { val text = """ ```kotlin import org.junit.jupiter.api.Assertions.* @@ -109,17 +112,149 @@ class KotlinJUnitTestSuiteParserTest { } ``` """.trimIndent() - val parser = KotlinJUnitTestSuiteParser("org.my.package", JUnitVersion.JUnit5, kotlinImportPattern) + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertTrue(testSuite!!.imports.contains("import org.mockito.Mockito.*")) + assertTrue(testSuite.imports.contains("import org.test.Message as TestMessage")) + assertTrue(testSuite.imports.contains("import org.mockito.kotlin.mock")) + + val expectedTestCasesNames = listOf( + "compileTestCases_AllCompilableTest", + "compileTestCases_NoneCompilableTest", + "compileTestCases_SomeCompilableTest", + "compileTestCases_EmptyTestCasesTest", + "compileTestCases_omg", + ) + + testSuite.testCases.forEachIndexed { index, testCase -> + val expected = expectedTestCasesNames[index] + assertEquals(expected, testCase.name) { "${index + 1}st test case has incorrect name" } + } + + assertTrue(testSuite.testCases[4].expectedException.isNotBlank()) + } + + @Test + fun testParseEmptyTestSuite() { + val text = """ + ```kotlin + package com.example.testsuite + + class EmptyTestClass { + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(testSuite!!.packageName, "com.example.testsuite") + assertTrue(testSuite.testCases.isEmpty()) + } + + @Test + fun testParseSingleTestCase() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class SingleTestCaseClass { + @Test + fun singleTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) assertNotNull(testSuite) - assert(testSuite.imports.contains("import org.mockito.Mockito.*")) - assert(testSuite.imports.contains("import org.test.Message as TestMessage")) - assert(testSuite.imports.contains("import org.mockito.kotlin.mock")) - assert(testSuite.testCases[0].name == "compileTestCases_AllCompilableTest") - assert(testSuite.testCases[1].name == "compileTestCases_NoneCompilableTest") - assert(testSuite.testCases[2].name == "compileTestCases_SomeCompilableTest") - assert(testSuite.testCases[3].name == "compileTestCases_EmptyTestCasesTest") - assert(testSuite.testCases[4].name == "compileTestCases_omg") - assert(testSuite.testCases[4].expectedException.isNotBlank()) + assertEquals(1, testSuite!!.testCases.size) + assertEquals("singleTestCase", testSuite.testCases[0].name) + } + + @Test + fun testParseTwoTestCases() { + val text = """ + ```kotlin + import org.junit.jupiter.api.Test + + class TwoTestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + + @Test + fun secondTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = + KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + assertNotNull(testSuite) + assertEquals(2, testSuite!!.testCases.size) + assertEquals("firstTestCase", testSuite.testCases[0].name) + assertEquals("secondTestCase", testSuite.testCases[1].name) + } + + @Test + fun testParseTwoTestCasesWithDifferentPackage() { + val code1 = """ + ```kotlin + package org.pkg1 + + import org.junit.jupiter.api.Test + + class TestCasesClass1 { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val code2 = """ + ```kotlin + package org.pkg2 + + import org.junit.jupiter.api.Test + + class 2TestCasesClass { + @Test + fun firstTestCase() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = KotlinJUnitTestSuiteParser("", JUnitVersion.JUnit5, testBodyPrinter) + + // packageName will be set to 'org.pkg1' + val testSuite1 = parser.parseTestSuite(code1) + + val testSuite2 = parser.parseTestSuite(code2) + + assertNotNull(testSuite1) + assertNotNull(testSuite2) + assertEquals("org.pkg1", testSuite1!!.packageName) + assertEquals("org.pkg2", testSuite2!!.packageName) } } diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index 007bdbff7..087485827 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -14,6 +14,7 @@ import org.jetbrains.research.testspark.core.utils.javaImportPattern import org.jetbrains.research.testspark.core.utils.javaPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -33,29 +34,12 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - javaPackagePattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - javaImportPattern.findAll(fileText).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + javaPackagePattern, + javaImportPattern, + ) override val classType: ClassType get() { @@ -68,6 +52,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { return ClassType.CLASS } + override val rBrace: Int? = psiClass.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val query = ClassInheritorsSearch.search(psiClass, scope, false) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index 8b513deda..f6f132a29 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -4,23 +4,27 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiMethod import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Java + override val language: SupportedLanguage get() = SupportedLanguage.Java private val log = Logger.getInstance(this::class.java) @@ -63,7 +67,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -84,7 +88,7 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { val cutPsiClass = getSurroundingClass(caretOffset)!! var currentPsiClass = cutPsiClass @@ -138,39 +142,44 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + // The cut is always not null for Java, because all functions are always inside the class + val interestingPsiClasses = cut!!.getInterestingPsiClassesWithQualifiedNames(psiMethod) log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val javaPsiClassWrapped = getSurroundingClass(caret.offset) as JavaPsiClassWrapper? val javaPsiMethodWrapped = getSurroundingMethod(caret.offset) as JavaPsiMethodWrapper? - val line: Int? = getSurroundingLine(caret.offset) - - javaPsiClassWrapped?.let { result.add(getClassHTMLDisplayName(it)) } - javaPsiMethodWrapped?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (javaPsiClassWrapped != null && javaPsiMethodWrapped != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${javaPsiClassWrapped.qualifiedName} \n" + - " 2) Method ${javaPsiMethodWrapped.name} \n" + - " 3) Line $line", - ) - } + val line: Int? = getSurroundingLineNumber(caret.offset) + + javaPsiClassWrapped?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + javaPsiMethodWrapped?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } + + log.info( + "The test can be generated for: \n " + + " 1) Class ${javaPsiClassWrapped?.qualifiedName ?: "no class"} \n" + + " 2) Method ${javaPsiMethodWrapped?.name ?: "no method"} \n" + + " 3) Line $line", + ) - return result.toArray() + return result } + override fun getPackageName() = (psiFile as PsiJavaFile).packageName + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index 8ac75755c..50cc12f0f 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -21,6 +21,7 @@ import org.jetbrains.research.testspark.core.utils.kotlinImportPattern import org.jetbrains.research.testspark.core.utils.kotlinPackagePattern import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper +import org.jetbrains.research.testspark.langwrappers.strategies.JavaKotlinClassTextExtractor class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWrapper { override val name: String get() = psiClass.name ?: "" @@ -61,29 +62,12 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra override val containingFile: PsiFile get() = psiClass.containingFile override val fullText: String - get() { - var fullText = "" - val fileText = psiClass.containingFile.text - - // get package - kotlinPackagePattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - kotlinImportPattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += psiClass.text - - return fullText - } + get() = JavaKotlinClassTextExtractor().extract( + psiClass.containingFile, + psiClass.text, + kotlinPackagePattern, + kotlinImportPattern, + ) override val classType: ClassType get() { @@ -97,6 +81,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra } } + override val rBrace: Int? = psiClass.body?.rBrace?.textRange?.startOffset + override fun searchSubclasses(project: Project): Collection { val scope = GlobalSearchScope.projectScope(project) val lightClass = psiClass.toLightClass() @@ -116,11 +102,9 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra method.psiFunction.valueParameters.forEach { parameter -> val typeReference = parameter.typeReference - if (typeReference != null) { - val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) - if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { - interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) - } + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index 13749bd35..fd8a78a1b 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -4,31 +4,26 @@ import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.editor.Caret +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange -import com.intellij.psi.PsiClass import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiFile import com.intellij.psi.util.parentOfType -import org.jetbrains.kotlin.asJava.toLightClass -import org.jetbrains.kotlin.descriptors.ClassDescriptor -import org.jetbrains.kotlin.idea.base.psi.kotlinFqName -import org.jetbrains.kotlin.idea.caches.resolve.analyze -import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtFunction -import org.jetbrains.kotlin.psi.KtTypeReference -import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils -import org.jetbrains.kotlin.resolve.lazy.BodyResolveMode -import org.jetbrains.research.testspark.langwrappers.Language +import org.jetbrains.kotlin.psi.KtPsiUtil +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType +import org.jetbrains.research.testspark.langwrappers.CodeTypeDisplayName import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper -class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { +class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { - override val language: Language get() = Language.Kotlin + override val language: SupportedLanguage get() = SupportedLanguage.Kotlin private val log = Logger.getInstance(this::class.java) @@ -66,7 +61,7 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { return null } - override fun getSurroundingLine(caretOffset: Int): Int? { + override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null val selectedLine = doc.getLineNumber(caretOffset) @@ -85,9 +80,10 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { project: Project, classesToTest: MutableList, caretOffset: Int, - maxPolymorphismDepth: Int, // check if cut has any non-java super class + maxPolymorphismDepth: Int, ) { - val cutPsiClass = getSurroundingClass(caretOffset)!! + val cutPsiClass = getSurroundingClass(caretOffset) ?: return + // will be null for the top level function var currentPsiClass = cutPsiClass for (index in 0 until maxPolymorphismDepth) { if (!classesToTest.contains(currentPsiClass)) { @@ -116,19 +112,13 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { repeat(maxInputParamsDepth) { val tempListOfClasses = mutableSetOf() - currentLevelClasses.forEach { classIt -> classIt.methods.forEach { methodIt -> (methodIt as KotlinPsiMethodWrapper).parameterList?.parameters?.forEach { paramIt -> - val typeRef = paramIt.typeReference - if (typeRef != null) { - resolveClassInType(typeRef)?.let { psiClass -> - if (psiClass.kotlinFqName != null) { - KotlinPsiClassWrapper(psiClass as KtClass).let { - if (!it.qualifiedName.startsWith("kotlin.")) { - interestingPsiClasses.add(it) - } - } + KtPsiUtil.getClassIfParameterIsProperty(paramIt)?.let { typeIt -> + KotlinPsiClassWrapper(typeIt).let { + if (!it.qualifiedName.startsWith("kotlin.")) { + interestingPsiClasses.add(it) } } } @@ -143,39 +133,45 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { } override fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet { - val interestingPsiClasses = cut.getInterestingPsiClassesWithQualifiedNames(psiMethod) + val interestingPsiClasses = + cut?.getInterestingPsiClassesWithQualifiedNames(psiMethod) + ?: (psiMethod as KotlinPsiMethodWrapper).getInterestingPsiClassesWithQualifiedNames() log.info("There are ${interestingPsiClasses.size} interesting psi classes from method ${psiMethod.methodDescriptor}") return interestingPsiClasses } - override fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? { - val result: ArrayList = arrayListOf() + override fun getCurrentListOfCodeTypes(e: AnActionEvent): List { + val result: ArrayList = arrayListOf() val caret: Caret = - e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result.toArray() + e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret ?: return result val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) - val line: Int? = getSurroundingLine(caret.offset)?.plus(1) - - ktClass?.let { result.add(getClassHTMLDisplayName(it)) } - ktFunction?.let { result.add(getMethodHTMLDisplayName(it)) } - line?.let { result.add(getLineHTMLDisplayName(it)) } - - if (ktClass != null && ktFunction != null) { - log.info( - "The test can be generated for: \n " + - " 1) Class ${ktClass.qualifiedName} \n" + - " 2) Method ${ktFunction.name} \n" + - " 3) Line $line", - ) - } + val line: Int? = getSurroundingLineNumber(caret.offset)?.plus(1) + + ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } + ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } + line?.let { result.add(CodeType.LINE to getLineHTMLDisplayName(it)) } - return result.toArray() + log.info( + "The test can be generated for: \n " + + " 1) Class ${ktClass?.qualifiedName ?: "no class"} \n" + + " 2) Method ${ktFunction?.name ?: "no method"} \n" + + " 3) Line $line", + ) + + return result } + override fun getPackageName() = (psiFile as KtFile).packageFqName.asString() + + override fun getModuleFromPsiFile() = ModuleUtilCore.findModuleForFile(psiFile.virtualFile, psiFile.project)!! + + override fun getDocumentFromPsiFile() = psiFile.fileDocument + override fun getLineHTMLDisplayName(line: Int) = "line $line" override fun getClassHTMLDisplayName(psiClass: PsiClassWrapper): String = @@ -184,18 +180,11 @@ class KotlinPsiHelper(var psiFile: PsiFile) : PsiHelper { override fun getMethodHTMLDisplayName(psiMethod: PsiMethodWrapper): String { psiMethod as KotlinPsiMethodWrapper return when { - psiMethod.isTopLevelFunction -> "top-level function" + psiMethod.isTopLevelFunction -> "top-level function ${psiMethod.name}" psiMethod.isSecondaryConstructor -> "secondary constructor" psiMethod.isPrimaryConstructor -> "constructor" psiMethod.isDefaultMethod -> "default method ${psiMethod.name}" else -> "method ${psiMethod.name}" } } - - private fun resolveClassInType(typeReference: KtTypeReference): PsiClass? { - val context = typeReference.analyze(BodyResolveMode.PARTIAL) - val type = context[BindingContext.TYPE, typeReference] ?: return null - val classDescriptor = type.constructor.declarationDescriptor as? ClassDescriptor ?: return null - return (DescriptorToSourceUtils.getSourceFromDescriptor(classDescriptor) as? KtClass)?.toLightClass() - } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index a142aaaa8..c993fd808 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -68,6 +68,26 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { return lineNumber in startLine..endLine } + /** + * Returns a set of `PsiClassWrapper` instances for non-standard Kotlin classes referenced by the + * parameters of the current function. + * + * @return A mutable set of `PsiClassWrapper` instances representing non-standard Kotlin classes. + */ + fun getInterestingPsiClassesWithQualifiedNames(): MutableSet { + val interestingPsiClasses = mutableSetOf() + + psiFunction.valueParameters.forEach { parameter -> + val typeReference = parameter.typeReference + val psiClass = PsiTreeUtil.getParentOfType(typeReference, KtClass::class.java) + if (psiClass != null && psiClass.fqName != null && !psiClass.fqName.toString().startsWith("kotlin.")) { + interestingPsiClasses.add(KotlinPsiClassWrapper(psiClass)) + } + } + + return interestingPsiClasses + } + /** * Generates the return descriptor for a method. * diff --git a/langwrappers/build.gradle.kts b/langwrappers/build.gradle.kts index 74ec82496..317debb35 100644 --- a/langwrappers/build.gradle.kts +++ b/langwrappers/build.gradle.kts @@ -5,7 +5,6 @@ plugins { repositories { mavenCentral() - // Add any other repositories you need } dependencies { @@ -17,7 +16,6 @@ dependencies { intellij { rootProject.properties["platformVersion"]?.let { version.set(it.toString()) } plugins.set(listOf("java")) - downloadSources.set(true) } tasks.named("verifyPlugin") { enabled = false } diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index f61dc7a1b..0aa5dfd0f 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -1,11 +1,15 @@ package org.jetbrains.research.testspark.langwrappers import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.editor.Document import com.intellij.openapi.project.Project import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiFile import org.jetbrains.research.testspark.core.data.ClassType -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType + +typealias CodeTypeDisplayName = Pair /** * Interface representing a wrapper for PSI methods, @@ -40,12 +44,14 @@ interface PsiMethodWrapper { * @property name The name of a class * @property qualifiedName The qualified name of the class. * @property text The text of the class. - * @property fullText The source code of the class (with package and imports). - * @property virtualFile - * @property containingFile File where the method is located - * @property superClass The super class of the class * @property methods All methods in the class * @property allMethods All methods in the class and all its superclasses + * @property superClass The super class of the class + * @property virtualFile Virtual file where the class is located + * @property containingFile File where the method is located + * @property fullText The source code of the class (with package and imports). + * @property classType The type of the class + * @property rBrace The offset of the closing brace * */ interface PsiClassWrapper { val name: String @@ -58,6 +64,7 @@ interface PsiClassWrapper { val containingFile: PsiFile val fullText: String val classType: ClassType + val rBrace: Int? /** * Searches for subclasses of the current class within the given project. @@ -81,7 +88,7 @@ interface PsiClassWrapper { * handling the PSI (Program Structure Interface) for different languages. */ interface PsiHelper { - val language: Language + val language: SupportedLanguage /** * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile. @@ -107,7 +114,7 @@ interface PsiHelper { * @param caretOffset The caret offset within the PSI file. * @return The line number of the selected line, otherwise null. */ - fun getSurroundingLine(caretOffset: Int): Int? + fun getSurroundingLineNumber(caretOffset: Int): Int? /** * Retrieves a set of interesting PsiClasses based on a given project, @@ -133,7 +140,7 @@ interface PsiHelper { * @return A mutable set of interesting PsiClasses. */ fun getInterestingPsiClassesWithQualifiedNames( - cut: PsiClassWrapper, + cut: PsiClassWrapper?, psiMethod: PsiMethodWrapper, ): MutableSet @@ -145,7 +152,7 @@ interface PsiHelper { * The array contains the class display name, method display name (if present), and the line number (if present). * The line number is prefixed with "Line". */ - fun getCurrentListOfCodeTypes(e: AnActionEvent): Array<*>? + fun getCurrentListOfCodeTypes(e: AnActionEvent): List /** * Helper for generating method descriptors for methods. @@ -160,8 +167,8 @@ interface PsiHelper { * * @param project The project in which to collect classes to test. * @param classesToTest The list of classes to test. - * @param psiHelper The PSI helper instance to use for collecting classes. * @param caretOffset The caret offset in the file. + * @param maxPolymorphismDepth Check if cut has any user-defined superclass */ fun collectClassesToTest( project: Project, @@ -170,6 +177,21 @@ interface PsiHelper { maxPolymorphismDepth: Int, ) + /** + * Get the package name of the file. + */ + fun getPackageName(): String + + /** + * Get the module of the file. + */ + fun getModuleFromPsiFile(): com.intellij.openapi.module.Module + + /** + * Get the module of the file. + */ + fun getDocumentFromPsiFile(): Document? + /** * Gets the display line number. * This is used when displaying the name of a method in the GenerateTestsActionMethod menu entry. diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 5a0a96fbc..a6f342882 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -17,6 +17,7 @@ import org.jetbrains.research.testspark.actions.llm.LLMSetupPanelFactory import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.display.TestSparkIcons import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider @@ -76,7 +77,6 @@ class TestSparkAction : AnAction() { if (psiHelper == null) { // TODO exception } - e.presentation.isEnabled = psiHelper!!.getCurrentListOfCodeTypes(e) != null } /** @@ -111,18 +111,18 @@ class TestSparkAction : AnAction() { return psiHelper!! } - private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e)!! + private val codeTypes = psiHelper.getCurrentListOfCodeTypes(e) private val caretOffset: Int = e.dataContext.getData(CommonDataKeys.CARET)?.caretModel?.primaryCaret!!.offset private val fileUrl = e.dataContext.getData(CommonDataKeys.VIRTUAL_FILE)!!.presentableUrl - private val codeTypeButtons: MutableList = mutableListOf() + private val codeTypeButtons: MutableList> = mutableListOf() private val codeTypeButtonGroup = ButtonGroup() private val nextButton = JButton(PluginLabelsBundle.get("next")) private val cardLayout = CardLayout() private val llmSetupPanelFactory = LLMSetupPanelFactory(e, project) - private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project) + private val llmSampleSelectorFactory = LLMSampleSelectorFactory(project, psiHelper.language) private val evoSuitePanelFactory = EvoSuitePanelFactory(project) init { @@ -198,16 +198,19 @@ class TestSparkAction : AnAction() { testGeneratorPanel.add(llmButton) testGeneratorPanel.add(evoSuiteButton) - for (codeType in codeTypes) { - val button = JRadioButton(codeType as String) - codeTypeButtons.add(button) + for ((codeType, codeTypeName) in codeTypes) { + val button = JRadioButton(codeTypeName) + codeTypeButtons.add(codeType to button) codeTypeButtonGroup.add(button) } val codesToTestPanel = JPanel() codesToTestPanel.add(JLabel("Select the code type:")) - if (codeTypeButtons.size == 1) codeTypeButtons[0].isSelected = true - for (button in codeTypeButtons) codesToTestPanel.add(button) + if (codeTypeButtons.size == 1) { + // A single button is selected by default + codeTypeButtons[0].second.isSelected = true + } + for ((_, button) in codeTypeButtons) codesToTestPanel.add(button) val middlePanel = FormBuilder.createFormBuilder() .setFormLeftIndent(10) @@ -253,7 +256,7 @@ class TestSparkAction : AnAction() { updateNextButton() } - for (button in codeTypeButtons) { + for ((_, button) in codeTypeButtons) { button.addActionListener { llmSetupPanelFactory.setPromptEditorType(button.text) updateNextButton() @@ -330,33 +333,36 @@ class TestSparkAction : AnAction() { if (!testGenerationController.isGeneratorRunning(project)) { val testSamplesCode = llmSampleSelectorFactory.getTestSamplesCode() - if (codeTypeButtons[0].isSelected) { - tool.generateTestsForClass( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[1].isSelected) { - tool.generateTestsForMethod( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) - } else if (codeTypeButtons[2].isSelected) { - tool.generateTestsForLine( - project, - psiHelper, - caretOffset, - fileUrl, - testSamplesCode, - testGenerationController, - ) + for ((codeType, button) in codeTypeButtons) { + if (button.isSelected) { + when (codeType) { + CodeType.CLASS -> tool.generateTestsForClass( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.METHOD -> tool.generateTestsForMethod( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + CodeType.LINE -> tool.generateTestsForLine( + project, + psiHelper, + caretOffset, + fileUrl, + testSamplesCode, + testGenerationController, + ) + } + break + } } } @@ -376,10 +382,7 @@ class TestSparkAction : AnAction() { */ private fun updateNextButton() { val isTestGeneratorButtonGroupSelected = llmButton.isSelected || evoSuiteButton.isSelected - var isCodeTypeButtonGroupSelected = false - for (button in codeTypeButtons) { - isCodeTypeButtonGroupSelected = isCodeTypeButtonGroupSelected || button.isSelected - } + val isCodeTypeButtonGroupSelected = codeTypeButtons.any { it.second.isSelected } nextButton.isEnabled = isTestGeneratorButtonGroupSelected && isCodeTypeButtonGroupSelected if ((llmButton.isSelected && !llmSettingsState.llmSetupCheckBoxSelected && !llmSettingsState.provideTestSamplesCheckBoxSelected) || diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt index b57ee8d81..b6b77a0ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSampleSelectorFactory.kt @@ -4,6 +4,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.actions.template.PanelFactory import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.helpers.LLMTestSampleHelper import java.awt.Font import javax.swing.ButtonGroup @@ -12,7 +13,7 @@ import javax.swing.JLabel import javax.swing.JPanel import javax.swing.JRadioButton -class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { +class LLMSampleSelectorFactory(private val project: Project, private val language: SupportedLanguage) : PanelFactory { // init components private val selectionTypeButtons: MutableList = mutableListOf( JRadioButton(PluginLabelsBundle.get("provideTestSample")), @@ -128,7 +129,7 @@ class LLMSampleSelectorFactory(private val project: Project) : PanelFactory { } addButton.addActionListener { - val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes) + val testSamplePanelFactory = TestSamplePanelFactory(project, middlePanel, testNames, initialTestCodes, language) testSamplePanelFactories.add(testSamplePanelFactory) val testSamplePanel = testSamplePanelFactory.getTestSamplePanel() val codeScrollPanel = testSamplePanelFactory.getCodeScrollPanel() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt index 69d5db9f3..8afe31fc8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelFactory.kt @@ -34,7 +34,7 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan private val defaultModulesArray = arrayOf("") private var modelSelector = ComboBox(defaultModulesArray) private var llmUserTokenField = JTextField(30) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) private val backLlmButton = JButton(PluginLabelsBundle.get("back")) private val okLlmButton = JButton(PluginLabelsBundle.get("next")) private val junitSelector = JUnitCombobox(e) @@ -142,6 +142,10 @@ class LLMSetupPanelFactory(e: AnActionEvent, private val project: Project) : Pan llmSettingsState.grazieToken = llmPlatforms[index].token llmSettingsState.grazieModel = llmPlatforms[index].model } + if (llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = llmPlatforms[index].token + llmSettingsState.huggingFaceModel = llmPlatforms[index].model + } } llmSettingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt index 97cf6d49a..251a45f27 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/TestSamplePanelFactory.kt @@ -10,6 +10,7 @@ import com.intellij.openapi.ui.ComboBox import com.intellij.ui.LanguageTextField import com.intellij.ui.components.JBScrollPane import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.display.IconButtonCreator import org.jetbrains.research.testspark.display.ModifiedLinesGetter import org.jetbrains.research.testspark.display.TestCaseDocumentCreator @@ -25,11 +26,12 @@ class TestSamplePanelFactory( private val middlePanel: JPanel, private val testNames: MutableList, private val initialTestCodes: MutableList, + private val language: SupportedLanguage, ) { // init components private val currentTestCodes = initialTestCodes.toMutableList() private val languageTextField = LanguageTextField( - Language.findLanguageByID("JAVA"), + Language.findLanguageByID(language.languageId), project, initialTestCodes[0], TestCaseDocumentCreator("TestSample"), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index b8b0654d3..499abf1c1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -18,7 +18,8 @@ import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.llm.JsonEncoding @@ -26,6 +27,7 @@ import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.Llm @@ -172,6 +174,12 @@ class TestSparkStarter : ApplicationStarter { // Start test generation val indicator = HeadlessProgressIndicator() val errorMonitor = DefaultErrorMonitor() + val testCompiler = TestCompilerFactory.create( + project, + settingsState.junitVersion, + psiHelper.language, + projectSDKPath.toString(), + ) val uiContext = llmProcessManager.runTestGenerator( indicator, FragmentToTestData(CodeType.CLASS), @@ -192,6 +200,7 @@ class TestSparkStarter : ApplicationStarter { classPath, projectContext, projectSDKPath, + testCompiler, ) } else { println("[TestSpark Starter] Test generation failed") @@ -237,6 +246,7 @@ class TestSparkStarter : ApplicationStarter { classPath: String, projectContext: ProjectContext, projectSDKPath: Path, + testCompiler: TestCompiler, ) { val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}" println("Run tests in $targetDirectory") @@ -246,6 +256,7 @@ class TestSparkStarter : ApplicationStarter { var testcaseName = it.nameWithoutExtension.removePrefix("Generated") testcaseName = testcaseName[0].lowercaseChar() + testcaseName.substring(1) // The current test is compiled and is ready to run jacoco + val testExecutionError = TestProcessor(project, projectSDKPath).createXmlFromJacoco( it.nameWithoutExtension, "$targetDirectory${File.separator}jacoco-${it.nameWithoutExtension}", @@ -254,6 +265,7 @@ class TestSparkStarter : ApplicationStarter { packageList.joinToString("."), out, projectContext, + testCompiler, ) // Saving exception (if exists) thrown during the test execution saveException(testcaseName, targetDirectory, testExecutionError) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt deleted file mode 100644 index 8e91aded4..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/CodeType.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.jetbrains.research.testspark.data - -/** -* Enum class, which contains all code elements for which it is possible to request test generation. -*/ -enum class CodeType { - CLASS, METHOD, LINE -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt index 0cf79dddb..3c289bb11 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/data/FragmentToTestData.kt @@ -1,5 +1,7 @@ package org.jetbrains.research.testspark.data +import org.jetbrains.research.testspark.core.test.data.CodeType + /** * Data about test objects that require test generators. */ diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt index f17e8720b..99b0ec5ab 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestCasePanelFactory.kt @@ -25,17 +25,20 @@ import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.helpers.ReportHelper import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestClassCodeAnalyzerFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.test.JUnitTestSuitePresenter @@ -58,7 +61,7 @@ import javax.swing.border.MatteBorder class TestCasePanelFactory( private val project: Project, - private val language: org.jetbrains.research.testspark.core.utils.Language, + private val language: SupportedLanguage, private val testCase: TestCase, editor: Editor, private val checkbox: JCheckBox, @@ -193,7 +196,10 @@ class TestCasePanelFactory( val clipboard: Clipboard = Toolkit.getDefaultToolkit().systemClipboard clipboard.setContents( StringSelection( - project.service().getEditor(testCase.testName)!!.document.text, + when (language) { + SupportedLanguage.Kotlin -> project.service().getEditor(testCase.testName)!!.document.text + SupportedLanguage.Java -> project.service().getEditor(testCase.testName)!!.document.text + }, ), null, ) @@ -386,7 +392,10 @@ class TestCasePanelFactory( } ReportHelper.updateTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service().updateUI() + SupportedLanguage.Java -> project.service().updateUI() + } } /** @@ -454,12 +463,12 @@ class TestCasePanelFactory( } private fun addTest(testSuite: TestSuiteGeneratedByLLM) { - val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput) + val testSuitePresenter = JUnitTestSuitePresenter(project, uiContext!!.testGenerationOutput, language) WriteCommandAction.runWriteCommandAction(project) { uiContext.errorMonitor.clear() val code = testSuitePresenter.toString(testSuite) - testCase.testName = JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, code) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, code) testCase.testCode = code // update numbers @@ -517,15 +526,24 @@ class TestCasePanelFactory( private fun runTest(indicator: CustomProgressIndicator) { indicator.setText("Executing ${testCase.testName}") + val fileName = TestClassCodeAnalyzerFactory.create(language).getFileNameFromTestCaseCode(testCase.testName) + + val testCompiler = TestCompilerFactory.create( + project, + llmSettingsState.junitVersion, + language, + ) + val newTestCase = TestProcessor(project) .processNewTestCase( - "${JavaClassBuilderHelper.getClassFromTestCaseCode(testCase.testCode)}.java", + fileName, testCase.id, testCase.testName, testCase.testCode, - uiContext!!.testGenerationOutput.packageLine, + uiContext!!.testGenerationOutput.packageName, uiContext.testGenerationOutput.resultPath, uiContext.projectContext, + testCompiler, ) testCase.coveredLines = newTestCase.coveredLines @@ -585,13 +603,23 @@ class TestCasePanelFactory( */ private fun remove() { // Remove the test case from the cache - project.service().removeTestCase(testCase.testName) + when (language) { + SupportedLanguage.Kotlin -> project.service().removeTestCase(testCase.testName) + + SupportedLanguage.Java -> project.service().removeTestCase(testCase.testName) + } runTestButton.isEnabled = false isRemoved = true ReportHelper.removeTestCase(project, report, testCase) - project.service().updateUI() + when (language) { + SupportedLanguage.Kotlin -> project.service() + .updateUI() + + SupportedLanguage.Java -> project.service() + .updateUI() + } } /** @@ -663,8 +691,7 @@ class TestCasePanelFactory( * Updates the current test case with the specified test name and test code. */ private fun updateTestCaseInformation() { - testCase.testName = - JavaClassBuilderHelper.getTestMethodNameFromClassWithTestCase(testCase.testName, languageTextField.document.text) + testCase.testName = TestClassCodeAnalyzerFactory.create(language).extractFirstTestMethodName(testCase.testName, languageTextField.document.text) testCase.testCode = languageTextField.document.text } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt index 31cc7b9a6..1a5938be1 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt @@ -1,197 +1,13 @@ package org.jetbrains.research.testspark.display -import com.intellij.openapi.components.service -import com.intellij.openapi.progress.ProgressIndicator -import com.intellij.openapi.progress.ProgressManager -import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.display.custom.IJProgressIndicator -import org.jetbrains.research.testspark.services.TestCaseDisplayService -import java.awt.Dimension -import java.util.LinkedList -import java.util.Queue -import javax.swing.Box -import javax.swing.BoxLayout -import javax.swing.JButton -import javax.swing.JCheckBox -import javax.swing.JLabel -import javax.swing.JOptionPane -import javax.swing.JPanel +import org.jetbrains.research.testspark.core.test.SupportedLanguage class TopButtonsPanelFactory(private val project: Project) { - private var runAllButton: JButton = createRunAllTestButton() - private var selectAllButton: JButton = - IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) - private var unselectAllButton: JButton = - IconButtonCreator.getButton(TestSparkIcons.unselectAll, PluginLabelsBundle.get("unselectAllTip")) - private var removeAllButton: JButton = - IconButtonCreator.getButton(TestSparkIcons.removeAll, PluginLabelsBundle.get("removeAllTip")) - - private var testsSelectedText: String = "${PluginLabelsBundle.get("testsSelected")}: %d/%d" - private var testsSelectedLabel: JLabel = JLabel(testsSelectedText) - - private val testsPassedText: String = "${PluginLabelsBundle.get("testsPassed")}: %d/%d" - private var testsPassedLabel: JLabel = JLabel(testsPassedText) - - private val testCasePanelFactories = arrayListOf() - - fun getPanel(): JPanel { - val panel = JPanel() - panel.layout = BoxLayout(panel, BoxLayout.X_AXIS) - panel.preferredSize = Dimension(0, 30) - panel.add(Box.createRigidArea(Dimension(10, 0))) - panel.add(testsPassedLabel) - panel.add(Box.createRigidArea(Dimension(10, 0))) - panel.add(testsSelectedLabel) - panel.add(Box.createHorizontalGlue()) - panel.add(runAllButton) - panel.add(selectAllButton) - panel.add(unselectAllButton) - panel.add(removeAllButton) - - selectAllButton.addActionListener { toggleAllCheckboxes(true) } - unselectAllButton.addActionListener { toggleAllCheckboxes(false) } - removeAllButton.addActionListener { removeAllTestCases() } - runAllButton.addActionListener { runAllTestCases() } - - return panel - } - - /** - * Updates the labels. - */ - fun updateTopLabels() { - var numberOfPassedTests = 0 - for (testCasePanelFactory in testCasePanelFactories) { - if (testCasePanelFactory.isRemoved()) continue - val error = testCasePanelFactory.getError() - if ((error is String) && error.isEmpty()) { - numberOfPassedTests++ - } - } - testsSelectedLabel.text = String.format( - testsSelectedText, - project.service().getTestsSelected(), - project.service().getTestCasePanels().size, - ) - testsPassedLabel.text = - String.format( - testsPassedText, - numberOfPassedTests, - project.service().getTestCasePanels().size, - ) - runAllButton.isEnabled = false - for (testCasePanelFactory in testCasePanelFactories) { - runAllButton.isEnabled = runAllButton.isEnabled || testCasePanelFactory.isRunEnabled() - } - } - - /** - * Sets the array of TestCasePanelFactory objects. - * - * @param testCasePanelFactories The ArrayList containing the TestCasePanelFactory objects to be set. - */ - fun setTestCasePanelFactoriesArray(testCasePanelFactories: ArrayList) { - this.testCasePanelFactories.addAll(testCasePanelFactories) - } - - /** - * Toggles check boxes so that they are either all selected or all not selected, - * depending on the provided parameter. - * - * @param selected whether the checkboxes have to be selected or not - */ - private fun toggleAllCheckboxes(selected: Boolean) { - project.service().getTestCasePanels().forEach { (_, jPanel) -> - val checkBox = jPanel.getComponent(0) as JCheckBox - checkBox.isSelected = selected - } - project.service() - .setTestsSelected(if (selected) project.service().getTestCasePanels().size else 0) - } - - /** - * Removes all test cases from the cache and tool window UI. - */ - private fun removeAllTestCases() { - // Ask the user for the confirmation - val choice = JOptionPane.showConfirmDialog( - null, - PluginMessagesBundle.get("removeAllMessage"), - PluginMessagesBundle.get("confirmationTitle"), - JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE, - ) - - // Cancel the operation if the user did not press "Yes" - if (choice == JOptionPane.NO_OPTION) return - - project.service().clear() - } - - /** - * Executes all test cases. - * - * This method presents a caution message to the user and asks for confirmation before executing the test cases. - * If the user confirms, it iterates through each test case panel factory and runs the corresponding test. - */ - private fun runAllTestCases() { - val choice = JOptionPane.showConfirmDialog( - null, - PluginMessagesBundle.get("runCautionMessage"), - PluginMessagesBundle.get("confirmationTitle"), - JOptionPane.OK_CANCEL_OPTION, - JOptionPane.WARNING_MESSAGE, - ) - - if (choice == JOptionPane.CANCEL_OPTION) return - - runAllButton.isEnabled = false - - // add each test generation task to queue - val tasks: Queue<(CustomProgressIndicator) -> Unit> = LinkedList() - - for (testCasePanelFactory in testCasePanelFactories) { - testCasePanelFactory.addTask(tasks) + fun create(language: SupportedLanguage): TestSuiteView { + return when (language) { + SupportedLanguage.Java -> JavaTestSuiteView(project) + SupportedLanguage.Kotlin -> KotlinTestSuiteView(project) } - // run tasks one after each other - executeTasks(tasks) - } - - private fun executeTasks(tasks: Queue<(CustomProgressIndicator) -> Unit>) { - val nextTask = tasks.poll() - - nextTask?.let { task -> - ProgressManager.getInstance().run(object : Task.Backgroundable(project, "Test execution") { - override fun run(indicator: ProgressIndicator) { - task(IJProgressIndicator(indicator)) - } - - override fun onFinished() { - super.onFinished() - executeTasks(tasks) - } - }) - } - } - - /** - * Creates a JButton for running all tests. - * - * @return a JButton for running all tests - */ - private fun createRunAllTestButton(): JButton { - val runTestButton = JButton(PluginLabelsBundle.get("runAll"), TestSparkIcons.runTest) - runTestButton.isOpaque = false - runTestButton.isContentAreaFilled = false - runTestButton.isBorderPainted = true - return runTestButton - } - - fun clear() { - testCasePanelFactories.clear() } -} +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt index bcad7a834..dee6a2b0e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/CoverageHelper.kt @@ -16,7 +16,7 @@ import com.intellij.ui.components.JBLabel import com.intellij.ui.components.JBScrollPane import com.intellij.util.ui.FormBuilder import org.jetbrains.research.testspark.services.EvoSuiteSettingsService -import org.jetbrains.research.testspark.services.TestCaseDisplayService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState import java.awt.Color import java.awt.Dimension @@ -130,7 +130,7 @@ class CoverageHelper( * @param name name of the test to highlight */ private fun highlightInToolwindow(name: String) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightTestCase(name) } @@ -141,7 +141,7 @@ class CoverageHelper( * @param map map of mutant operations -> List of names of tests which cover the mutants */ private fun highlightMutantsInToolwindow(mutantName: String, map: HashMap>) { - val testCaseDisplayService = project.service() + val testCaseDisplayService = project.service() testCaseDisplayService.highlightCoveredMutants(map.getOrPut(mutantName) { ArrayList() }) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt deleted file mode 100644 index 977873bdb..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/JavaClassBuilderHelper.kt +++ /dev/null @@ -1,204 +0,0 @@ -package org.jetbrains.research.testspark.helpers - -import com.github.javaparser.ParseProblemException -import com.github.javaparser.StaticJavaParser -import com.github.javaparser.ast.CompilationUnit -import com.github.javaparser.ast.body.MethodDeclaration -import com.github.javaparser.ast.visitor.VoidVisitorAdapter -import com.intellij.lang.java.JavaLanguage -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.project.Project -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiFile -import com.intellij.psi.PsiFileFactory -import com.intellij.psi.codeStyle.CodeStyleManager -import org.jetbrains.research.testspark.core.data.TestGenerationData -import java.io.File - -object JavaClassBuilderHelper { - /** - * Generates the code for a test class. - * - * @param className the name of the test class - * @param body the body of the test class - * @return the generated code as a string - */ - fun generateCode( - project: Project, - className: String, - body: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - testGenerationData: TestGenerationData, - ): String { - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) - - // Add each test (exclude expected exception) - testFullText += body - - // close the test class - testFullText += "}" - - testFullText.replace("\r\n", "\n") - - /** - * for better readability and make the tests shorter, we reduce the number of line breaks: - * when we have three or more sequential \n, reduce it to two. - */ - return formatJavaCode(project, Regex("\n\n\n(\n)*").replace(testFullText, "\n\n"), testGenerationData) - } - - /** - * Returns the upper part of test suite (package name, imports, and test class name) as a string. - * - * @return the upper part of test suite (package name, imports, and test class name) as a string. - */ - private fun printUpperPart( - className: String, - imports: Set, - packageString: String, - runWith: String, - otherInfo: String, - ): String { - var testText = "" - - // Add package - if (packageString.isNotBlank()) { - testText += "package $packageString;\n" - } - - // add imports - imports.forEach { importedElement -> - testText += "$importedElement\n" - } - - testText += "\n" - - // add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith)\n" - } - // open the test class - testText += "public class $className {\n\n" - - // Add other presets (annotations, non-test functions) - if (otherInfo.isNotBlank()) { - testText += otherInfo - } - - return testText - } - - /** - * Finds the test method from a given class with the specified test case name. - * - * @param code The code of the class containing test methods. - * @return The test method as a string, including the "@Test" annotation. - */ - fun getTestMethodCodeFromClassWithTestCase(code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result += "\t" + method.toString().replace("\n", "\n\t") + "\n\n" - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - val upperCutCode = "\t@Test" + code.split("@Test").last() - var methodStarted = false - var balanceOfBrackets = 0 - for (symbol in upperCutCode) { - result += symbol - if (symbol == '{') { - methodStarted = true - balanceOfBrackets++ - } - if (symbol == '}') { - balanceOfBrackets-- - } - if (methodStarted && balanceOfBrackets == 0) { - break - } - } - return result + "\n" - } - } - - /** - * Retrieves the name of the test method from a given Java class with test cases. - * - * @param oldTestCaseName The old name of test case - * @param code The source code of the Java class with test cases. - * @return The name of the test method. If no test method is found, an empty string is returned. - */ - fun getTestMethodNameFromClassWithTestCase(oldTestCaseName: String, code: String): String { - var result = "" - try { - val componentUnit: CompilationUnit = StaticJavaParser.parse(code) - - object : VoidVisitorAdapter() { - override fun visit(method: MethodDeclaration, arg: Any?) { - super.visit(method, arg) - if (method.getAnnotationByName("Test").isPresent) { - result = method.nameAsString - } - } - }.visit(componentUnit, null) - - return result - } catch (e: ParseProblemException) { - return oldTestCaseName - } - } - - /** - * Retrieves the class name from the given test case code. - * - * @param code The test case code to extract the class name from. - * @return The class name extracted from the test case code. - */ - fun getClassFromTestCaseCode(code: String): String { - val pattern = Regex("public\\s+class\\s+(\\S+)\\s*\\{") - val matchResult = pattern.find(code) - matchResult ?: return "GeneratedTest" - val (className) = matchResult.destructured - return className - } - - /** - * Formats the given Java code using IntelliJ IDEA's code formatting rules. - * - * @param code The Java code to be formatted. - * @return The formatted Java code. - */ - fun formatJavaCode(project: Project, code: String, generatedTestData: TestGenerationData): String { - var result = "" - WriteCommandAction.runWriteCommandAction(project) { - val fileName = generatedTestData.resultPath + File.separatorChar + "Formatted.java" - // create a temporary PsiFile - val psiFile: PsiFile = PsiFileFactory.getInstance(project) - .createFileFromText( - fileName, - JavaLanguage.INSTANCE, - code, - ) - - CodeStyleManager.getInstance(project).reformat(psiFile) - - val document = PsiDocumentManager.getInstance(project).getDocument(psiFile) - result = document?.text ?: code - - File(fileName).delete() - } - - return result - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index b36fe381a..d10525087 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -12,15 +12,19 @@ import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModif import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIPlatform import java.net.HttpURLConnection import javax.swing.DefaultComboBoxModel @@ -67,6 +71,9 @@ object LLMHelper { if (platformSelector.selectedItem!!.toString() == settingsState.grazieName) { models = getGrazieModels() } + if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) { + models = getHuggingFaceModels() + } modelSelector.model = DefaultComboBoxModel(models) for (index in llmPlatforms.indices) { if (llmPlatforms[index].name == settingsState.openAIName && @@ -81,6 +88,12 @@ object LLMHelper { modelSelector.selectedItem = settingsState.grazieModel llmPlatforms[index].model = modelSelector.selectedItem!!.toString() } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + modelSelector.selectedItem = settingsState.huggingFaceModel + llmPlatforms[index].model = modelSelector.selectedItem!!.toString() + } } modelSelector.isEnabled = true if (models.contentEquals(arrayOf(""))) modelSelector.isEnabled = false @@ -112,6 +125,12 @@ object LLMHelper { llmUserTokenField.text = settingsState.grazieToken llmPlatforms[index].token = settingsState.grazieToken } + if (llmPlatforms[index].name == settingsState.huggingFaceName && + llmPlatforms[index].name == platformSelector.selectedItem!!.toString() + ) { + llmUserTokenField.text = settingsState.huggingFaceToken + llmPlatforms[index].token = settingsState.huggingFaceToken + } } } @@ -185,8 +204,6 @@ object LLMHelper { if (isGrazieClassLoaded()) { platformSelector.model = DefaultComboBoxModel(llmPlatforms.map { it.name }.toTypedArray()) platformSelector.selectedItem = settingsState.currentLLMPlatformName - } else { - platformSelector.isEnabled = false } llmUserTokenField.toolTipText = LLMSettingsBundle.get("llmToken") @@ -202,7 +219,7 @@ object LLMHelper { * @return The list of LLMPlatforms. */ fun getLLLMPlatforms(): List { - return listOf(OpenAIPlatform(), GraziePlatform()) + return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform()) } /** @@ -230,7 +247,7 @@ object LLMHelper { * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ fun testModificationRequest( - language: Language, + language: SupportedLanguage, testCase: String, task: String, indicator: CustomProgressIndicator, @@ -244,13 +261,28 @@ object LLMHelper { return null } + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + ) + + val testsAssembler = TestsAssemblerFactory.create( + indicator, + testGenerationOutput, + testSuiteParser, + jUnitVersion, + ) + val testSuite = executeTestCaseModificationRequest( language, testCase, task, indicator, requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, testGenerationOutput), + testsAssembler, errorMonitor, ) return testSuite @@ -328,4 +360,13 @@ object LLMHelper { arrayOf("") } } + + /** + * Retrieves the available HuggingFace models. + * + * @return an array of string representing the available HuggingFace models + */ + private fun getHuggingFaceModels(): Array { + return arrayOf("Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct") + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/CoverageToolWindowDisplayService.kt deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt index e3b11555a..6b257f421 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/TestCaseDisplayService.kt @@ -1,425 +1,69 @@ package org.jetbrains.research.testspark.services -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.components.Service -import com.intellij.openapi.components.service -import com.intellij.openapi.fileChooser.FileChooser -import com.intellij.openapi.fileChooser.FileChooserDescriptor -import com.intellij.openapi.fileEditor.FileDocumentManager -import com.intellij.openapi.fileEditor.FileEditorManager -import com.intellij.openapi.fileEditor.OpenFileDescriptor -import com.intellij.openapi.fileEditor.TextEditor -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.LocalFileSystem -import com.intellij.openapi.vfs.VirtualFile -import com.intellij.openapi.vfs.VirtualFileManager -import com.intellij.openapi.wm.ToolWindowManager -import com.intellij.psi.PsiClass -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiElementFactory -import com.intellij.psi.PsiJavaFile -import com.intellij.psi.PsiManager -import com.intellij.refactoring.suggested.startOffset +import com.intellij.psi.PsiFile import com.intellij.ui.EditorTextField -import com.intellij.ui.JBColor -import com.intellij.ui.components.JBScrollPane -import com.intellij.ui.content.Content -import com.intellij.ui.content.ContentFactory -import com.intellij.ui.content.ContentManager -import com.intellij.util.containers.stream -import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle -import org.jetbrains.research.testspark.bundles.plugin.PluginSettingsBundle import org.jetbrains.research.testspark.core.data.Report -import org.jetbrains.research.testspark.core.data.TestCase -import org.jetbrains.research.testspark.core.utils.Language +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.data.UIContext -import org.jetbrains.research.testspark.display.TestCasePanelFactory -import org.jetbrains.research.testspark.display.TopButtonsPanelFactory -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper -import org.jetbrains.research.testspark.helpers.ReportHelper -import java.awt.BorderLayout -import java.awt.Color -import java.awt.Dimension -import java.io.File -import java.util.Locale -import javax.swing.Box -import javax.swing.BoxLayout -import javax.swing.JButton -import javax.swing.JCheckBox -import javax.swing.JOptionPane +import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper import javax.swing.JPanel -import javax.swing.JSeparator -import javax.swing.SwingConstants -@Service(Service.Level.PROJECT) -class TestCaseDisplayService(private val project: Project) { - private var report: Report? = null - - private val unselectedTestCases = HashMap() - - private var mainPanel: JPanel = JPanel() - - private val topButtonsPanelFactory = TopButtonsPanelFactory(project) - - private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) - - private var allTestCasePanel: JPanel = JPanel() - - private var scrollPane: JBScrollPane = JBScrollPane( - allTestCasePanel, - JBScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JBScrollPane.HORIZONTAL_SCROLLBAR_NEVER, - ) - - private var testCasePanels: HashMap = HashMap() - - private var testsSelected: Int = 0 - - /** - * Default color for the editors in the tool window - */ - private var defaultEditorColor: Color? = null - - /** - * Content Manager to be able to add / remove tabs from tool window - */ - private var contentManager: ContentManager? = null - - /** - * Variable to keep reference to the coverage visualisation content - */ - private var content: Content? = null - - var uiContext: UIContext? = null - - init { - allTestCasePanel.layout = BoxLayout(allTestCasePanel, BoxLayout.Y_AXIS) - mainPanel.layout = BorderLayout() - - mainPanel.add(topButtonsPanelFactory.getPanel(), BorderLayout.NORTH) - mainPanel.add(scrollPane, BorderLayout.CENTER) - - applyButton.isOpaque = false - applyButton.isContentAreaFilled = false - mainPanel.add(applyButton, BorderLayout.SOUTH) - - applyButton.addActionListener { applyTests() } - } +interface TestCaseDisplayService { /** * Fill the panel with the generated test cases. Remove all previously shown test cases. * Add Tests and their names to a List of pairs (used for highlighting) */ - fun displayTestCases(report: Report, uiContext: UIContext, language: Language) { - this.report = report - this.uiContext = uiContext - - val editor = project.service().editor!! - - allTestCasePanel.removeAll() - testCasePanels.clear() - - addSeparator() - - // TestCasePanelFactories array - val testCasePanelFactories = arrayListOf() - - report.testCaseList.values.forEach { - val testCase = it - val testCasePanel = JPanel() - testCasePanel.layout = BorderLayout() - - // Add a checkbox to select the test - val checkbox = JCheckBox() - checkbox.isSelected = true - checkbox.addItemListener { - // Update the number of selected tests - testsSelected -= (1 - 2 * checkbox.isSelected.compareTo(false)) - - if (checkbox.isSelected) { - ReportHelper.selectTestCase(project, report, unselectedTestCases, testCase.id) - } else { - ReportHelper.unselectTestCase(project, report, unselectedTestCases, testCase.id) - } - - updateUI() - } - testCasePanel.add(checkbox, BorderLayout.WEST) - - val testCasePanelFactory = TestCasePanelFactory(project, language, testCase, editor, checkbox, uiContext, report) - testCasePanel.add(testCasePanelFactory.getUpperPanel(), BorderLayout.NORTH) - testCasePanel.add(testCasePanelFactory.getMiddlePanel(), BorderLayout.CENTER) - testCasePanel.add(testCasePanelFactory.getBottomPanel(), BorderLayout.SOUTH) - - testCasePanelFactories.add(testCasePanelFactory) - - testCasePanel.add(Box.createRigidArea(Dimension(12, 0)), BorderLayout.EAST) - - // Add panel to parent panel - testCasePanel.maximumSize = Dimension(Short.MAX_VALUE.toInt(), Short.MAX_VALUE.toInt()) - allTestCasePanel.add(testCasePanel) - addSeparator() - testCasePanels[testCase.testName] = testCasePanel - } - - // Update the number of selected tests (all tests are selected by default) - testsSelected = testCasePanels.size - - topButtonsPanelFactory.setTestCasePanelFactoriesArray(testCasePanelFactories) - topButtonsPanelFactory.updateTopLabels() - - createToolWindowTab() - } + fun displayTestCases(report: Report, uiContext: UIContext, language: SupportedLanguage) /** * Adds a separator to the allTestCasePanel. */ - private fun addSeparator() { - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - allTestCasePanel.add(JSeparator(SwingConstants.HORIZONTAL)) - allTestCasePanel.add(Box.createRigidArea(Dimension(0, 10))) - } + fun addSeparator() /** * Highlight the mini-editor in the tool window whose name corresponds with the name of the test provided * * @param name name of the test whose editor should be highlighted */ - fun highlightTestCase(name: String) { - val myPanel = testCasePanels[name] ?: return - openToolWindowTab() - scrollToPanel(myPanel) - - val editor = getEditor(name) ?: return - val settingsProjectState = project.service().state - val highlightColor = - JBColor( - PluginSettingsBundle.get("colorName"), - Color( - settingsProjectState.colorRed, - settingsProjectState.colorGreen, - settingsProjectState.colorBlue, - 30, - ), - ) - if (editor.background.equals(highlightColor)) return - defaultEditorColor = editor.background - editor.background = highlightColor - returnOriginalEditorBackground(editor) - } + fun highlightTestCase(name: String) /** * Method to open the toolwindow tab with generated tests if not already open. */ - private fun openToolWindowTab() { - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - toolWindowManager.show() - toolWindowManager.contentManager.setSelectedContent(content!!) - } - } + fun openToolWindowTab() /** * Scrolls to the highlighted panel. * * @param myPanel the panel to scroll to */ - private fun scrollToPanel(myPanel: JPanel) { - var sum = 0 - for (component in allTestCasePanel.components) { - if (component == myPanel) { - break - } else { - sum += component.height - } - } - val scroll = scrollPane.verticalScrollBar - scroll.value = (scroll.minimum + scroll.maximum) * sum / allTestCasePanel.height - } + fun scrollToPanel(myPanel: JPanel) /** * Removes all coverage highlighting from the editor. */ - private fun removeAllHighlights() { - project.service().editor?.markupModel?.removeAllHighlighters() - } + fun removeAllHighlights() /** * Reset the provided editors color to the default (initial) one after 10 seconds * @param editor the editor whose color to change */ - private fun returnOriginalEditorBackground(editor: EditorTextField) { - Thread { - Thread.sleep(10000) - editor.background = defaultEditorColor - }.start() - } + fun returnOriginalEditorBackground(editor: EditorTextField) /** * Highlight a range of editors * @param names list of test names to pass to highlight function */ - fun highlightCoveredMutants(names: List) { - names.forEach { - highlightTestCase(it) - } - } + fun highlightCoveredMutants(names: List) /** * Show a dialog where the user can select what test class the tests should be applied to, * and apply the selected tests to the test class. */ - private fun applyTests() { - // Filter the selected test cases - val selectedTestCasePanels = testCasePanels.filter { (it.value.getComponent(0) as JCheckBox).isSelected } - val selectedTestCases = selectedTestCasePanels.map { it.key } - - // Get the test case components (source code of the tests) - val testCaseComponents = selectedTestCases - .map { getEditor(it)!! } - .map { it.document.text } - - // Descriptor for choosing folders and java files - val descriptor = FileChooserDescriptor(true, true, false, false, false, false) - - // Apply filter with folders and java files with main class - WriteCommandAction.runWriteCommandAction(project) { - descriptor.withFileFilter { file -> - file.isDirectory || ( - file.extension?.lowercase(Locale.getDefault()) == "java" && ( - PsiManager.getInstance(project).findFile(file!!) as PsiJavaFile - ).classes.stream().map { it.name } - .toArray() - .contains( - ( - PsiManager.getInstance(project) - .findFile(file) as PsiJavaFile - ).name.removeSuffix(".java"), - ) - ) - } - } - - val fileChooser = FileChooser.chooseFiles( - descriptor, - project, - LocalFileSystem.getInstance().findFileByPath(project.basePath!!), - ) - - /** - * Cancel button pressed - */ - if (fileChooser.isEmpty()) return - - /** - * Chosen files by user - */ - val chosenFile = fileChooser[0] - - /** - * Virtual file of a final java file - */ - var virtualFile: VirtualFile? = null - - /** - * PsiClass of a final java file - */ - var psiClass: PsiClass? = null - - /** - * PsiJavaFile of a final java file - */ - var psiJavaFile: PsiJavaFile? = null - - if (chosenFile.isDirectory) { - // Input new file data - var className: String - var fileName: String - var filePath: String - // Waiting for correct file name input - while (true) { - val jOptionPane = - JOptionPane.showInputDialog( - null, - PluginLabelsBundle.get("optionPaneMessage"), - PluginLabelsBundle.get("optionPaneTitle"), - JOptionPane.PLAIN_MESSAGE, - null, - null, - null, - ) - - // Cancel button pressed - jOptionPane ?: return - - // Get class name from user - className = jOptionPane as String - - // Set file name and file path - fileName = "${className.split('.')[0]}.java" - filePath = "${chosenFile.path}/$fileName" - - // Check the correctness of a class name - if (!Regex("[A-Z][a-zA-Z0-9]*(.java)?").matches(className)) { - showErrorWindow(PluginLabelsBundle.get("incorrectFileNameMessage")) - continue - } - - // Check the existence of a file with this name - if (File(filePath).exists()) { - showErrorWindow(PluginLabelsBundle.get("fileAlreadyExistsMessage")) - continue - } - break - } - - // Create new file and set services of this file - WriteCommandAction.runWriteCommandAction(project) { - chosenFile.createChildData(null, fileName) - virtualFile = VirtualFileManager.getInstance().findFileByUrl("file://$filePath")!! - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) + fun applyTests() - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { - psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext!!.testGenerationOutput.runWith})") - } - - psiJavaFile!!.add(psiClass!!) - } - } else { - // Set services of the chosen file - virtualFile = chosenFile - psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) - psiClass = psiJavaFile!!.classes[ - psiJavaFile!!.classes.stream().map { it.name }.toArray() - .indexOf(psiJavaFile!!.name.removeSuffix(".java")), - ] - } - - // Add tests to the file - WriteCommandAction.runWriteCommandAction(project) { - appendTestsToClass(testCaseComponents, psiClass!!, psiJavaFile!!) - } - - // Remove the selected test cases from the cache and the tool window UI - removeSelectedTestCases(selectedTestCasePanels) - - // Open the file after adding - FileEditorManager.getInstance(project).openTextEditor( - OpenFileDescriptor(project, virtualFile!!), - true, - ) - } - - private fun showErrorWindow(message: String) { - JOptionPane.showMessageDialog( - null, - message, - PluginLabelsBundle.get("errorWindowTitle"), - JOptionPane.ERROR_MESSAGE, - ) - } + fun showErrorWindow(message: String) /** * Retrieve the editor corresponding to a particular test case @@ -427,11 +71,7 @@ class TestCaseDisplayService(private val project: Project) { * @param testCaseName the name of the test case * @return the editor corresponding to the test case, or null if it does not exist */ - fun getEditor(testCaseName: String): EditorTextField? { - val middlePanelComponent = testCasePanels[testCaseName]?.getComponent(2) ?: return null - val middlePanel = middlePanelComponent as JPanel - return (middlePanel.getComponent(1) as JBScrollPane).viewport.view as EditorTextField - } + fun getEditor(testCaseName: String): EditorTextField? /** * Append the provided test cases to the provided class. @@ -440,107 +80,23 @@ class TestCaseDisplayService(private val project: Project) { * @param selectedClass the class which the test cases should be appended to * @param outputFile the output file for tests */ - private fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClass, outputFile: PsiJavaFile) { - // block document - PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument( - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!, - ) - - // insert tests to a code - testCaseComponents.reversed().forEach { - val testMethodCode = - JavaClassBuilderHelper.getTestMethodCodeFromClassWithTestCase( - JavaClassBuilderHelper.formatJavaCode( - project, - it.replace("\r\n", "\n") - .replace("verifyException(", "// verifyException("), - uiContext!!.testGenerationOutput, - ), - ) - // Fix Windows line separators - .replace("\r\n", "\n") - - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - testMethodCode, - ) - } - - // insert other info to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - selectedClass.rBrace!!.textRange.startOffset, - uiContext!!.testGenerationOutput.otherInfo + "\n", - ) - - // insert imports to a code - PsiDocumentManager.getInstance(project).getDocument(outputFile)!!.insertString( - outputFile.importList?.startOffset ?: outputFile.packageStatement?.startOffset ?: 0, - uiContext!!.testGenerationOutput.importsCode.joinToString("\n") + "\n\n", - ) - - // insert package to a code - outputFile.packageStatement ?: PsiDocumentManager.getInstance(project).getDocument(outputFile)!! - .insertString( - 0, - if (uiContext!!.testGenerationOutput.packageLine.isEmpty()) { - "" - } else { - "package ${uiContext!!.testGenerationOutput.packageLine};\n\n" - }, - ) - } + fun appendTestsToClass(testCaseComponents: List, selectedClass: PsiClassWrapper, outputFile: PsiFile) /** * Utility function that returns the editor for a specific file url, * in case it is opened in the IDE */ - fun updateEditorForFileUrl(fileUrl: String) { - val documentManager = FileDocumentManager.getInstance() - // https://intellij-support.jetbrains.com/hc/en-us/community/posts/360004480599/comments/360000703299 - FileEditorManager.getInstance(project).selectedEditors.map { it as TextEditor }.map { it.editor }.map { - val currentFile = documentManager.getFile(it.document) - if (currentFile != null) { - if (currentFile.presentableUrl == fileUrl) { - project.service().editor = it - } - } - } - } + fun updateEditorForFileUrl(fileUrl: String) /** * Creates a new toolWindow tab for the coverage visualisation. */ - private fun createToolWindowTab() { - // Remove generated tests tab from content manager if necessary - val toolWindowManager = ToolWindowManager.getInstance(project).getToolWindow("TestSpark") - contentManager = toolWindowManager!!.contentManager - if (content != null) { - contentManager!!.removeContent(content!!, true) - } - - // If there is no generated tests tab, make it - val contentFactory: ContentFactory = ContentFactory.getInstance() - content = contentFactory.createContent( - mainPanel, - PluginLabelsBundle.get("generatedTests"), - true, - ) - contentManager!!.addContent(content!!) - - // Focus on generated tests tab and open toolWindow if not opened already - contentManager!!.setSelectedContent(content!!) - toolWindowManager.show() - } + fun createToolWindowTab() /** * Closes the tool window and destroys the content of the tab. */ - private fun closeToolWindow() { - contentManager?.removeContent(content!!, true) - ToolWindowManager.getInstance(project).getToolWindow("TestSpark")?.hide() - val coverageVisualisationService = project.service() - coverageVisualisationService.closeToolWindowTab() - } + fun closeToolWindow() /** * Removes the selected tests from the cache, removes all the highlights from the editor and closes the tool window. @@ -549,37 +105,16 @@ class TestCaseDisplayService(private val project: Project) { * * @param selectedTestCasePanels the panels of the selected tests */ - private fun removeSelectedTestCases(selectedTestCasePanels: Map) { - selectedTestCasePanels.forEach { removeTestCase(it.key) } - removeAllHighlights() - closeToolWindow() - } - - fun clear() { - // Remove the tests - val testCasePanelsToRemove = testCasePanels.toMap() - removeSelectedTestCases(testCasePanelsToRemove) + fun removeSelectedTestCases(selectedTestCasePanels: Map) - topButtonsPanelFactory.clear() - } + fun clear() /** * A helper method to remove a test case from the cache and from the UI. * * @param testCaseName the name of the test */ - fun removeTestCase(testCaseName: String) { - // Update the number of selected test cases if necessary - if ((testCasePanels[testCaseName]!!.getComponent(0) as JCheckBox).isSelected) { - testsSelected-- - } - - // Remove the test panel from the UI - allTestCasePanel.remove(testCasePanels[testCaseName]) - - // Remove the test panel - testCasePanels.remove(testCaseName) - } + fun removeTestCase(testCaseName: String) /** * Updates the user interface of the tool window. @@ -589,36 +124,26 @@ class TestCaseDisplayService(private val project: Project) { * of the topButtonsPanel object. It also checks if there are no more tests remaining * and closes the tool window if that is the case. */ - fun updateUI() { - // Update the UI of the tool window tab - allTestCasePanel.updateUI() - - topButtonsPanelFactory.updateTopLabels() - - // If no more tests are remaining, close the tool window - if (testCasePanels.size == 0) closeToolWindow() - } + fun updateUI() /** * Retrieves the list of test case panels. * * @return The list of test case panels. */ - fun getTestCasePanels() = testCasePanels + fun getTestCasePanels(): HashMap /** * Retrieves the currently selected tests. * * @return The list of tests currently selected. */ - fun getTestsSelected() = testsSelected + fun getTestsSelected(): Int /** * Sets the number of tests selected. * * @param testsSelected The number of tests selected. */ - fun setTestsSelected(testsSelected: Int) { - this.testsSelected = testsSelected - } + fun setTestsSelected(testsSelected: Int) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt index 89e480e83..6c3d77a05 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt @@ -45,7 +45,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { // Models private var modelSelector = ComboBox(arrayOf("")) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) // Default LLM Requests private var defaultLLMRequestsSeparator = diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt index 5f792b328..2b0ff5769 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsConfigurable.kt @@ -42,6 +42,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab settingsComponent!!.llmPlatforms[index].token = llmSettingsState.grazieToken settingsComponent!!.llmPlatforms[index].model = llmSettingsState.grazieModel } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + settingsComponent!!.llmPlatforms[index].token = llmSettingsState.huggingFaceToken + settingsComponent!!.llmPlatforms[index].model = llmSettingsState.huggingFaceModel + } } settingsComponent!!.currentLLMPlatformName = llmSettingsState.currentLLMPlatformName settingsComponent!!.maxLLMRequest = llmSettingsState.maxLLMRequest @@ -81,6 +85,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.grazieToken) modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.grazieModel) } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.huggingFaceToken) + modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.huggingFaceModel) + } } modified = modified or (settingsComponent!!.currentLLMPlatformName != llmSettingsState.currentLLMPlatformName) modified = modified or (settingsComponent!!.maxLLMRequest != llmSettingsState.maxLLMRequest) @@ -138,6 +146,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab llmSettingsState.grazieToken = settingsComponent!!.llmPlatforms[index].token llmSettingsState.grazieModel = settingsComponent!!.llmPlatforms[index].model } + if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.huggingFaceName) { + llmSettingsState.huggingFaceToken = settingsComponent!!.llmPlatforms[index].token + llmSettingsState.huggingFaceModel = settingsComponent!!.llmPlatforms[index].model + } } llmSettingsState.currentLLMPlatformName = settingsComponent!!.currentLLMPlatformName llmSettingsState.maxLLMRequest = settingsComponent!!.maxLLMRequest diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt index 3ce378707..590ec3c1d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt @@ -15,6 +15,9 @@ data class LLMSettingsState( var grazieName: String = DefaultLLMSettingsState.grazieName, var grazieToken: String = DefaultLLMSettingsState.grazieToken, var grazieModel: String = DefaultLLMSettingsState.grazieModel, + var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName, + var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken, + var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel, var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName, var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest, var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth, @@ -45,6 +48,9 @@ data class LLMSettingsState( val grazieName: String = LLMDefaultsBundle.get("grazieName") val grazieToken: String = LLMDefaultsBundle.get("grazieToken") val grazieModel: String = LLMDefaultsBundle.get("grazieModel") + val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName") + val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken") + val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel") var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName") val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt() val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt index 0cd1b073a..c4310ba61 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/LibraryPathsProvider.kt @@ -2,7 +2,7 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.application.PathManager import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.test.data.dependencies.JavaTestCompilationDependencies +import org.jetbrains.research.testspark.core.test.data.dependencies.TestCompilationDependencies import java.io.File /** @@ -16,7 +16,7 @@ class LibraryPathsProvider { private val sep = File.separatorChar private val libPrefix = "${PathManager.getPluginsPath()}${sep}TestSpark${sep}lib$sep" - fun getTestCompilationLibraryPaths() = JavaTestCompilationDependencies.getJarDescriptors().map { descriptor -> + fun getTestCompilationLibraryPaths() = TestCompilationDependencies.getJarDescriptors().map { descriptor -> "$libPrefix${sep}${descriptor.name}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index aa5b694b7..30ed0ba6b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task import com.intellij.openapi.project.Project -import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.roots.ProjectRootManager import com.intellij.openapi.util.io.FileUtilRt import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext @@ -22,6 +22,8 @@ import org.jetbrains.research.testspark.services.CoverageVisualisationService import org.jetbrains.research.testspark.services.EditorService import org.jetbrains.research.testspark.services.TestCaseDisplayService import org.jetbrains.research.testspark.services.TestsExecutionResultService +import org.jetbrains.research.testspark.services.java.JavaTestCaseDisplayService +import org.jetbrains.research.testspark.services.kotlin.KotlinTestCaseDisplayService import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.util.UUID @@ -29,7 +31,7 @@ import java.util.UUID * Pipeline class represents a pipeline for generating tests in a project. * * @param project the project in which the pipeline is executed. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caretOffset the offset of the caret position in the PSI file. * @param fileUrl the URL of the file being processed, if applicable. * @param packageName the package name of the file being processed. @@ -47,7 +49,7 @@ class Pipeline( init { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! + val cutPsiClass = psiHelper.getSurroundingClass(caretOffset) // get generated test path val testResultDirectory = "${FileUtilRt.getTempDirectory()}${ToolUtils.sep}testSparkResults${ToolUtils.sep}" @@ -57,10 +59,8 @@ class Pipeline( ApplicationManager.getApplication().runWriteAction { projectContext.projectClassPath = ProjectRootManager.getInstance(project).contentRoots.first().path projectContext.fileUrlAsString = fileUrl - projectContext.classFQN = cutPsiClass.qualifiedName - // TODO probably can be made easier - projectContext.cutModule = - ProjectFileIndex.getInstance(project).getModuleForFile(cutPsiClass.virtualFile)!! + cutPsiClass?.let { projectContext.classFQN = it.qualifiedName } + projectContext.cutModule = psiHelper.getModuleFromPsiFile() } generatedTestsData.resultPath = ToolUtils.getResultPath(id, testResultDirectory) @@ -108,14 +108,13 @@ class Pipeline( override fun onFinished() { super.onFinished() testGenerationController.finished() - uiContext?.let { - project.service() - .updateEditorForFileUrl(it.testGenerationOutput.fileUrl) - - if (project.service().editor != null) { - val report = it.testGenerationOutput.testGenerationResultList[0]!! - project.service().displayTestCases(report, it, psiHelper.language) - project.service().showCoverage(report) + when (psiHelper.language) { + SupportedLanguage.Java -> uiContext?.let { + displayTestCase(it) + } + + SupportedLanguage.Kotlin -> uiContext?.let { + displayTestCase(it) } } } @@ -124,8 +123,22 @@ class Pipeline( private fun clear(project: Project) { // should be removed totally! testGenerationController.errorMonitor.clear() - project.service().clear() + when (psiHelper.language) { + SupportedLanguage.Java -> project.service().clear() + SupportedLanguage.Kotlin -> project.service().clear() + } + project.service().clear() project.service().clear() } + + private inline fun displayTestCase(ctx: UIContext) { + project.service().updateEditorForFileUrl(ctx.testGenerationOutput.fileUrl) + + if (project.service().editor != null) { + val report = ctx.testGenerationOutput.testGenerationResultList[0]!! + project.service().displayTestCases(report, ctx, psiHelper.language) + project.service().showCoverage(report) + } + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt index 8680370bd..84b512bb5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestCompilerFactory.kt @@ -3,20 +3,31 @@ package org.jetbrains.research.testspark.tools import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.JUnitVersion +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler +import org.jetbrains.research.testspark.core.test.java.JavaTestCompiler +import org.jetbrains.research.testspark.core.test.kotlin.KotlinTestCompiler class TestCompilerFactory { companion object { - fun createJavacTestCompiler( + fun create( project: Project, junitVersion: JUnitVersion, + language: SupportedLanguage, javaHomeDirectory: String? = null, ): TestCompiler { - val javaHomePath = javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + val javaSDKHomePath = + javaHomeDirectory ?: ProjectRootManager.getInstance(project).projectSdk?.homeDirectory?.path + ?: throw RuntimeException("Java SDK not configured for the project.") + val libraryPaths = LibraryPathsProvider.getTestCompilationLibraryPaths() val junitLibraryPaths = LibraryPathsProvider.getJUnitLibraryPaths(junitVersion) - return TestCompiler(javaHomePath, libraryPaths, junitLibraryPaths) + // TODO add the warning window that for Java we always need the javaHomeDirectoryPath + return when (language) { + SupportedLanguage.Java -> JavaTestCompiler(libraryPaths, junitLibraryPaths, javaSDKHomePath) + SupportedLanguage.Kotlin -> KotlinTestCompiler(libraryPaths, junitLibraryPaths) + } } } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt index e0a4150b4..d35589357 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/TestProcessor.kt @@ -8,6 +8,7 @@ import com.intellij.openapi.roots.CompilerModuleExtension import com.intellij.openapi.roots.ModuleRootManager import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.core.data.TestCase +import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.utils.CommandLineRunner import org.jetbrains.research.testspark.core.utils.DataFilesUtil @@ -25,16 +26,20 @@ class TestProcessor( val project: Project, givenProjectSDKPath: Path? = null, ) : TestsPersistentStorage { - private val javaHomeDirectory = givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + private val homeDirectory = + givenProjectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path private val log = Logger.getInstance(this::class.java) private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state - val testCompiler = TestCompilerFactory.createJavacTestCompiler(project, llmSettingsState.junitVersion, javaHomeDirectory) - - override fun saveGeneratedTest(packageString: String, code: String, resultPath: String, testFileName: String): String { + override fun saveGeneratedTest( + packageString: String, + code: String, + resultPath: String, + testFileName: String, + ): String { // Generate the final path for the generated tests var generatedTestPath = "$resultPath${File.separatorChar}" packageString.split(".").forEach { directory -> @@ -69,14 +74,10 @@ class TestProcessor( generatedTestPackage: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): String { // find the proper javac - val javaRunner = File(javaHomeDirectory).walk() - .filter { - val isJavaName = if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") - isJavaName && it.isFile - } - .first() + val javaRunner = findJavaCompilerInDirectory(homeDirectory) // JaCoCo libs val jacocoAgentLibraryPath = "\"${LibraryPathsProvider.getJacocoAgentLibraryPath()}\"" val jacocoCLILibraryPath = "\"${LibraryPathsProvider.getJacocoCliLibraryPath()}\"" @@ -90,13 +91,21 @@ class TestProcessor( val junitVersion = llmSettingsState.junitVersion.version // run the test method with jacoco agent + log.info("[TestProcessor] Executing $name") val junitRunnerLibraryPath = LibraryPathsProvider.getJUnitRunnerLibraryPath() + // classFQN will be null for the top level function + val javaAgentFlag = + if (projectContext.classFQN != null) { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}" + } else { + "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false" + } val testExecutionError = CommandLineRunner.run( arrayListOf( javaRunner.absolutePath, - "-javaagent:$jacocoAgentLibraryPath=destfile=$dataFileName.exec,append=false,includes=${projectContext.classFQN}", + javaAgentFlag, "-cp", - "\"${testCompiler.getPath(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", + "\"${testCompiler.getClassPaths(projectBuildPath)}${DataFilesUtil.classpathSeparator}${junitRunnerLibraryPath}${DataFilesUtil.classpathSeparator}$resultPath\"", "org.jetbrains.research.SingleJUnitTestRunner$junitVersion", name, ), @@ -148,9 +157,10 @@ class TestProcessor( testId: Int, testName: String, testCode: String, - packageLine: String, + packageName: String, resultPath: String, projectContext: ProjectContext, + testCompiler: TestCompiler, ): TestCase { // get buildPath var buildPath: String = ProjectRootManager.getInstance(project).contentRoots.first().path @@ -161,7 +171,7 @@ class TestProcessor( // save new test to file val generatedTestPath: String = saveGeneratedTest( - packageLine, + packageName, testCode, resultPath, fileName, @@ -179,9 +189,10 @@ class TestProcessor( dataFileName, testName, buildPath, - packageLine, + packageName, resultPath, projectContext, + testCompiler, ) if (!File("$dataFileName.xml").exists()) { @@ -230,7 +241,8 @@ class TestProcessor( frames.removeFirst() frames.forEach { frame -> - if (frame.contains(projectContext.classFQN!!)) { + // classFQN will be null for the top level function + if (projectContext.classFQN != null && frame.contains(projectContext.classFQN!!)) { val coveredLineNumber = frame.split(":")[1].replace(")", "").toIntOrNull() if (coveredLineNumber != null) { result.add(coveredLineNumber) @@ -274,7 +286,8 @@ class TestProcessor( children("counter") {} } children("sourcefile") { - isCorrectSourceFile = this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() + isCorrectSourceFile = + this.attributes.getValue("name") == projectContext.fileUrlAsString!!.split(File.separatorChar).last() children("line") { if (isCorrectSourceFile && this.attributes.getValue("mi") == "0") { setOfLines.add(this.attributes.getValue("nr").toInt()) @@ -295,4 +308,18 @@ class TestProcessor( return TestCase(testCaseId, testCaseName, testCaseCode, setOfLines) } + + /** + * Finds 'javac' compiler (both on Unix & Windows) + * starting from the provided directory. + */ + private fun findJavaCompilerInDirectory(homeDirectory: String): File { + return File(homeDirectory).walk() + .filter { + val isJavaName = + if (DataFilesUtil.isWindows()) it.name.equals("java.exe") else it.name.equals("java") + isJavaName && it.isFile + } + .first() + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3ba26b9c5..a7ef25eb2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -11,9 +11,9 @@ import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.utils.DataFilesUtil import org.jetbrains.research.testspark.data.IJTestCase -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper import org.jetbrains.research.testspark.services.TestsExecutionResultService import java.io.File @@ -21,68 +21,37 @@ object ToolUtils { val sep = File.separatorChar val pathSep = File.pathSeparatorChar - /** - * Retrieves the imports code from a given test suite code. - * - * @param testSuiteCode The test suite code from which to extract the imports code. If null, an empty string is returned. - * @param classFQN The fully qualified name of the class to be excluded from the imports code. It will not be included in the result. - * @return The imports code extracted from the test suite code. If no imports are found or the result is empty after filtering, an empty string is returned. - */ - fun getImportsCodeFromTestSuiteCode(testSuiteCode: String?, classFQN: String): MutableSet { - testSuiteCode ?: return mutableSetOf() - return testSuiteCode.replace("\r\n", "\n").split("\n").asSequence() - .filter { it.contains("^import".toRegex()) } - .filterNot { it.contains("evosuite".toRegex()) } - .filterNot { it.contains("RunWith".toRegex()) } - .filterNot { it.contains(classFQN.toRegex()) }.toMutableSet() - } - - /** - * Retrieves the package declaration from the given test suite code. - * - * @param testSuiteCode The generated code of the test suite. - * @return The package declaration extracted from the test suite code, or an empty string if no package declaration was found. - */ -// get package from a generated code - fun getPackageFromTestSuiteCode(testSuiteCode: String?): String { - testSuiteCode ?: return "" - if (!testSuiteCode.contains("package")) return "" - val result = testSuiteCode.replace("\r\n", "\n").split("\n") - .filter { it.contains("^package".toRegex()) }.joinToString("").split("package ")[1].split(";")[0] - if (result.isBlank()) return "" - return result - } - /** * Saves the data related to test generation in the specified project's workspace. * * @param project The project in which the test generation data will be saved. * @param report The report object to be added to the test generation result list. - * @param packageLine The package declaration line of the test generation data. + * @param packageName The package declaration line of the test generation data. * @param importsCode The import statements code of the test generation data. */ fun saveData( project: Project, report: Report, - packageLine: String, + packageName: String, importsCode: MutableSet, fileUrl: String, generatedTestData: TestGenerationData, + language: SupportedLanguage = SupportedLanguage.Java, ) { generatedTestData.fileUrl = fileUrl - generatedTestData.packageLine = packageLine + generatedTestData.packageName = packageName generatedTestData.importsCode.addAll(importsCode) project.service().initExecutionResult(report.testCaseList.values.map { it.id }) for (testCase in report.testCaseList.values) { val code = testCase.testCode - testCase.testCode = JavaClassBuilderHelper.generateCode( + testCase.testCode = TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCase.testName), code, generatedTestData.importsCode, - generatedTestData.packageLine, + generatedTestData.packageName, generatedTestData.runWith, generatedTestData.otherInfo, generatedTestData, diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 46b982ac1..529bb4b8e 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -5,7 +5,7 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.actions.controllers.TestGenerationController -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.langwrappers.PsiHelper import org.jetbrains.research.testspark.langwrappers.PsiMethodWrapper @@ -88,7 +88,7 @@ class EvoSuite(override val name: String = "EvoSuite") : Tool { */ override fun generateTestsForLine(project: Project, psiHelper: PsiHelper, caretOffset: Int, fileUrl: String?, testSamplesCode: String, testGenerationController: TestGenerationController) { log.info("Starting tests generation for line by EvoSuite") - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! createPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( getEvoSuiteProcessManager(project), FragmentToTestData( diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt index c1e5e6560..8c180f9df 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt @@ -15,10 +15,13 @@ import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteDefaultsBundle import org.jetbrains.research.testspark.bundles.evosuite.EvoSuiteMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.core.utils.CommandLineRunner -import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext @@ -200,8 +203,8 @@ class EvoSuiteProcessManager( ToolUtils.saveData( project, IJReport(testGenerationResult), - ToolUtils.getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode), - ToolUtils.getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), + getPackageFromTestSuiteCode(testGenerationResult.testSuiteCode, SupportedLanguage.Java), + getImportsCodeFromTestSuiteCode(testGenerationResult.testSuiteCode, classFQN), projectContext.fileUrlAsString!!, generatedTestsData, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 01f16176c..980707a2a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -1,11 +1,12 @@ package org.jetbrains.research.testspark.tools.llm import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.actions.controllers.TestGenerationController import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.helpers.LLMHelper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -23,6 +24,8 @@ import java.nio.file.Path */ class Llm(override val name: String = "LLM") : Tool { + private val log = Logger.getInstance(this::class.java) + /** * Returns an instance of the LLMProcessManager. * @@ -74,6 +77,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for CLASS was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -107,6 +111,7 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for METHOD was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return @@ -141,11 +146,12 @@ class Llm(override val name: String = "LLM") : Tool { testSamplesCode: String, testGenerationController: TestGenerationController, ) { + log.info("Generation of tests for LINE was selected") if (!LLMHelper.isCorrectToken(project, testGenerationController.errorMonitor)) { testGenerationController.finished() return } - val selectedLine: Int = psiHelper.getSurroundingLine(caretOffset)!! + val selectedLine: Int = psiHelper.getSurroundingLineNumber(caretOffset)!! val codeType = FragmentToTestData(CodeType.LINE, selectedLine) createLLMPipeline(project, psiHelper, caretOffset, fileUrl, testGenerationController).runTestGeneration( LLMProcessManager( @@ -174,9 +180,7 @@ class Llm(override val name: String = "LLM") : Tool { fileUrl: String?, testGenerationController: TestGenerationController, ): Pipeline { - val cutPsiClass = psiHelper.getSurroundingClass(caretOffset)!! - val packageList = cutPsiClass.qualifiedName.split(".").dropLast(1) - val packageName = packageList.joinToString(".") + val packageName = psiHelper.getPackageName() return Pipeline(project, psiHelper, caretOffset, fileUrl, packageName, testGenerationController) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 437ecd679..271cf4b49 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -57,6 +57,7 @@ class LlmSettingsArguments(private val project: Project) { fun getToken(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIToken llmSettingsState.grazieName -> llmSettingsState.grazieToken + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken else -> "" } @@ -68,6 +69,7 @@ class LlmSettingsArguments(private val project: Project) { fun getModel(): String = when (currentLLMPlatformName()) { llmSettingsState.openAIName -> llmSettingsState.openAIModel llmSettingsState.grazieName -> llmSettingsState.grazieModel + llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel else -> "" } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index e1bcb67ec..1196016b2 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -1,36 +1,27 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestSuiteParser import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.test.parsers.TestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.java.JavaJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.parsers.kotlin.KotlinJUnitTestSuiteParser -import org.jetbrains.research.testspark.core.utils.Language -import org.jetbrains.research.testspark.core.utils.javaImportPattern -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState /** * Assembler class for generating and organizing test cases. * - * @property project The project to which the tests belong. * @property indicator The progress indicator to display the progress of test generation. * @property log The logger for logging debug information. * @property lastTestCount The count of the last generated tests. */ class JUnitTestsAssembler( - val project: Project, val indicator: CustomProgressIndicator, - val generationData: TestGenerationData, + private val generationData: TestGenerationData, + private val testSuiteParser: TestSuiteParser, + val junitVersion: JUnitVersion, ) : TestsAssembler() { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state private val log: Logger = Logger.getInstance(this.javaClass) @@ -58,11 +49,8 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(packageName: String, language: Language): TestSuiteGeneratedByLLM? { - val junitVersion = llmSettingsState.junitVersion - - val parser = createTestSuiteParser(packageName, junitVersion, language) - val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(super.getContent()) + override fun assembleTestSuite(): TestSuiteGeneratedByLLM? { + val testSuite = testSuiteParser.parseTestSuite(super.getContent()) // save RunWith if (testSuite?.runWith?.isNotBlank() == true) { @@ -80,15 +68,4 @@ class JUnitTestsAssembler( testSuite?.testCases?.forEach { testCase -> log.info("Generated test case: $testCase") } return testSuite } - - private fun createTestSuiteParser( - packageName: String, - jUnitVersion: JUnitVersion, - language: Language, - ): TestSuiteParser { - return when (language) { - Language.Java -> JavaJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - Language.Kotlin -> KotlinJUnitTestSuiteParser(packageName, jUnitVersion, javaImportPattern) - } - } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index bb1dee0ff..f46dd5603 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -3,25 +3,32 @@ package org.jetbrains.research.testspark.tools.llm.generation import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.FeedbackCycleExecutionResult import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle +import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode +import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.SupportedLanguage +import org.jetbrains.research.testspark.core.test.TestsPersistentStorage import org.jetbrains.research.testspark.core.test.TestsPresenter import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.core.utils.Language import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.IJReport import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.UIContext import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState +import org.jetbrains.research.testspark.tools.TestBodyPrinterFactory +import org.jetbrains.research.testspark.tools.TestCompilerFactory import org.jetbrains.research.testspark.tools.TestProcessor +import org.jetbrains.research.testspark.tools.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.TestsAssemblerFactory import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager @@ -34,7 +41,6 @@ import java.nio.file.Path * and is responsible for generating tests using the LLM tool. * * @property project The project in which the test generation is being performed. - * @property prompt The prompt to be sent to the LLM tool. * @property testFileName The name of the generated test file. * @property log An instance of the logger class for logging purposes. * @property llmErrorManager An instance of the LLMErrorManager class. @@ -42,19 +48,23 @@ import java.nio.file.Path */ class LLMProcessManager( private val project: Project, - private val language: Language, + private val language: SupportedLanguage, private val promptManager: PromptManager, private val testSamplesCode: String, - projectSDKPath: Path? = null, + private val projectSDKPath: Path? = null, ) : ProcessManager { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state - private val testFileName: String = "GeneratedTest.java" + private val homeDirectory = + projectSDKPath?.toString() ?: ProjectRootManager.getInstance(project).projectSdk!!.homeDirectory!!.path + + private val testFileName: String = when (language) { + SupportedLanguage.Java -> "GeneratedTest.java" + SupportedLanguage.Kotlin -> "GeneratedTest.kt" + } private val log = Logger.getInstance(this::class.java) private val llmErrorManager: LLMErrorManager = LLMErrorManager() private val maxRequests = LlmSettingsArguments(project).maxLLMRequest() - private val testProcessor = TestProcessor(project, projectSDKPath) + private val testProcessor: TestsPersistentStorage = TestProcessor(project, projectSDKPath) /** * Runs the test generator process. @@ -91,16 +101,16 @@ class LLMProcessManager( val report = IJReport() // PROMPT GENERATION - val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) - - val testCompiler = testProcessor.testCompiler + val initialPromptMessage = + promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) // initiate a new RequestManager val requestManager = StandardRequestManagerFactory(project).getRequestManager(project) // adapter for the existing prompt reduction functionality val promptSizeReductionStrategy = object : PromptSizeReductionStrategy { - override fun isReductionPossible(): Boolean = promptManager.isPromptSizeReductionPossible(generatedTestsData) + override fun isReductionPossible(): Boolean = + promptManager.isPromptSizeReductionPossible(generatedTestsData) override fun reduceSizeAndGeneratePrompt(): String { if (!isReductionPossible()) { @@ -115,7 +125,7 @@ class LLMProcessManager( // adapter for the existing test case/test suite string representing functionality val testsPresenter = object : TestsPresenter { - private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + private val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) override fun representTestSuite(testSuite: TestSuiteGeneratedByLLM): String { return testSuitePresenter.toStringWithoutExpectedException(testSuite) @@ -126,6 +136,29 @@ class LLMProcessManager( } } + // Creation of JUnit specific parser, printer and assembler + val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + val testBodyPrinter = TestBodyPrinterFactory.create(language) + val testSuiteParser = TestSuiteParserFactory.createJUnitTestSuiteParser( + jUnitVersion, + language, + testBodyPrinter, + packageName, + ) + val testsAssembler = TestsAssemblerFactory.create( + indicator, + generatedTestsData, + testSuiteParser, + jUnitVersion, + ) + + val testCompiler = TestCompilerFactory.create( + project, + jUnitVersion, + language, + homeDirectory, + ) + // Asking LLM to generate a test suite. Here we have a feedback cycle for LLM in case of wrong responses val llmFeedbackCycle = LLMWithFeedbackCycle( language = language, @@ -137,7 +170,7 @@ class LLMProcessManager( resultPath = generatedTestsData.resultPath, buildPath = buildPath, requestManager = requestManager, - testsAssembler = JUnitTestsAssembler(project, indicator, generatedTestsData), + testsAssembler = testsAssembler, testCompiler = testCompiler, testStorage = testProcessor, testsPresenter = testsPresenter, @@ -150,8 +183,10 @@ class LLMProcessManager( when (warning) { LLMWithFeedbackCycle.WarningType.TEST_SUITE_PARSING_FAILED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.NO_TEST_CASES_GENERATED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("emptyResponse"), project) + LLMWithFeedbackCycle.WarningType.COMPILATION_ERROR_OCCURRED -> llmErrorManager.warningProcess(LLMMessagesBundle.get("compilationError"), project) } @@ -167,17 +202,21 @@ class LLMProcessManager( // store compilable test cases generatedTestsData.compilableTestCases.addAll(feedbackResponse.compilableTestCases) } + FeedbackCycleExecutionResult.NO_COMPILABLE_TEST_CASES_GENERATED -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("invalidLLMResult"), project, errorMonitor) } + FeedbackCycleExecutionResult.CANCELED -> { log.info("Process stopped") return null } + FeedbackCycleExecutionResult.PROVIDED_PROMPT_TOO_LONG -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("tooLongPromptRequest"), project, errorMonitor) return null } + FeedbackCycleExecutionResult.SAVING_TEST_FILES_ISSUE -> { llmErrorManager.errorProcess(LLMMessagesBundle.get("savingTestFileIssue"), project, errorMonitor) } @@ -190,7 +229,7 @@ class LLMProcessManager( log.info("Save generated test suite and test cases into the project workspace") - val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData) + val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) val generatedTestSuite: TestSuiteGeneratedByLLM? = feedbackResponse.generatedTestSuite val testSuiteRepresentation = if (generatedTestSuite != null) testSuitePresenter.toString(generatedTestSuite) else null @@ -200,10 +239,11 @@ class LLMProcessManager( ToolUtils.saveData( project, report, - ToolUtils.getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation), - ToolUtils.getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN!!), + getPackageFromTestSuiteCode(testSuiteCode = testSuiteRepresentation, language), + getImportsCodeFromTestSuiteCode(testSuiteRepresentation, projectContext.classFQN), projectContext.fileUrlAsString!!, generatedTestsData, + language, ) return UIContext(projectContext, generatedTestsData, requestManager, errorMonitor) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index d7ac8f9f5..08e5be765 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -5,7 +5,6 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.util.Computable import com.intellij.openapi.util.TextRange -import com.intellij.psi.PsiDocumentManager import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle import org.jetbrains.research.testspark.core.data.TestGenerationData @@ -15,7 +14,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptConfiguration import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptGenerationContext import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.PromptTemplates -import org.jetbrains.research.testspark.data.CodeType +import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.llm.JsonEncoding import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper @@ -31,7 +30,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager * A class that manages prompts for generating unit tests. * * @constructor Creates a PromptManager with the given parameters. - * @param psiHelper The PsiHelper in the context of witch the pipeline is executed. + * @param psiHelper The PsiHelper in the context of which the pipeline is executed. * @param caret The place of the caret. */ class PromptManager( @@ -39,6 +38,9 @@ class PromptManager( private val psiHelper: PsiHelper, private val caret: Int, ) { + /** + * The `classesToTest` is empty when we work with the function outside the class + */ private val classesToTest: List get() { val classesToTest = mutableListOf() @@ -52,7 +54,10 @@ class PromptManager( return classesToTest } - private val cut: PsiClassWrapper = classesToTest[0] + /** + * The `cut` is null when we work with the function outside the class. + */ + private val cut: PsiClassWrapper? = if (classesToTest.isNotEmpty()) classesToTest[0] else null private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state @@ -79,7 +84,7 @@ class PromptManager( .toMap() val context = PromptGenerationContext( - cut = createClassRepresentation(cut), + cut = cut?.let { createClassRepresentation(it) }, classesToTest = classesToTest.map(this::createClassRepresentation).toList(), polymorphismRelations = polymorphismRelations, promptConfiguration = PromptConfiguration( @@ -110,7 +115,12 @@ class PromptManager( .map(this::createClassRepresentation) .toList() - promptGenerator.generatePromptForMethod(method, interestingClassesFromMethod, testSamplesCode) + promptGenerator.generatePromptForMethod( + method, + interestingClassesFromMethod, + testSamplesCode, + psiHelper.getPackageName(), + ) } CodeType.LINE -> { @@ -118,7 +128,7 @@ class PromptManager( val psiMethod = getPsiMethod(cut, getMethodDescriptor(cut, lineNumber))!! // get code of line under test - val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) + val document = psiHelper.getDocumentFromPsiFile() val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) val lineEndOffset = document.getLineEndOffset(lineNumber - 1) @@ -149,7 +159,7 @@ class PromptManager( signature = psiMethod.signature, name = psiMethod.name, text = psiMethod.text!!, - containingClassQualifiedName = psiMethod.containingClass!!.qualifiedName, + containingClassQualifiedName = psiMethod.containingClass?.qualifiedName ?: "", ) } @@ -210,7 +220,6 @@ class PromptManager( * * @param project The project context in which the PsiClasses exist. * @param interestingPsiClasses The set of PsiClassWrappers that are considered interesting. - * @param cutPsiClass The cut PsiClassWrapper to determine polymorphism relations against. * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. */ private fun getPolymorphismRelationsWithQualifiedNames( @@ -219,6 +228,9 @@ class PromptManager( ): MutableMap> { val polymorphismRelations: MutableMap> = mutableMapOf() + // assert(interestingPsiClasses.isEmpty()) + if (cut == null) return polymorphismRelations + interestingPsiClasses.add(cut) interestingPsiClasses.forEach { currentInterestingClass -> @@ -245,9 +257,14 @@ class PromptManager( * @return The matching PsiMethod if found, otherwise an empty string. */ private fun getPsiMethod( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, methodDescriptor: String, ): PsiMethodWrapper? { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return currentPsiMethod + } for (currentPsiMethod in psiClass.allMethods) { val file = psiClass.containingFile val psiHelper = PsiHelperProvider.getPsiHelper(file) @@ -268,9 +285,14 @@ class PromptManager( * @return the method descriptor as a String, or an empty string if no method is found */ private fun getMethodDescriptor( - psiClass: PsiClassWrapper, + psiClass: PsiClassWrapper?, lineNumber: Int, ): String { + // Processing function outside the class + if (psiClass == null) { + val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + return psiHelper.generateMethodDescriptor(currentPsiMethod) + } for (currentPsiMethod in psiClass.allMethods) { if (currentPsiMethod.containsLine(lineNumber)) { val file = psiClass.containingFile diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt index 46daefc30..f05d55986 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt @@ -6,6 +6,7 @@ import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager +import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager interface RequestManagerFactory { @@ -20,6 +21,7 @@ class StandardRequestManagerFactory(private val project: Project) : RequestManag return when (val platform = LlmSettingsArguments(project).currentLLMPlatformName()) { llmSettingsState.openAIName -> OpenAIRequestManager(project) llmSettingsState.grazieName -> GrazieRequestManager(project) + llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project) else -> throw IllegalStateException("Unknown selected platform: $platform") } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt index c2267beb8..45581b8cf 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt @@ -62,14 +62,12 @@ class GrazieRequestManager(project: Project) : IJRequestManager(project) { } private fun getMessages(): List> { - val result = mutableListOf>() - chatHistory.forEach { + return chatHistory.map { val role = when (it.role) { ChatMessage.ChatRole.User -> "user" ChatMessage.ChatRole.Assistant -> "assistant" } - result.add(Pair(role, it.content)) + (role to it.content) } - return result } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt index 40e0c3fba..33138c4f8 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt @@ -1,9 +1,30 @@ package org.jetbrains.research.testspark.tools.llm.generation.openai -import org.jetbrains.research.testspark.core.data.ChatMessage +/** + * Adheres the naming of fields for OpenAI chat completion API and checks the correctness of a `role`. + *
+ * Use this class as a carrier of messages that should be sent to OpenAI API. + */ +data class OpenAIChatMessage(val role: String, val content: String) { + private companion object { + /** + * The API strictly defines the set of roles. + * The `function` role is omitted because it is already deprecated. + * + * See: https://platform.openai.com/docs/api-reference/chat/create + */ + val supportedRoles = listOf("user", "assistant", "system", "tool") + } + + init { + if (!supportedRoles.contains(role)) { + throw IllegalArgumentException("'$role' is not supported ${OpenAIChatMessage::class}. Available roles are: ${(supportedRoles.joinToString(", ") { "'$it'" })}") + } + } +} data class OpenAIRequestBody( val model: String, - val messages: List, + val messages: List, val stream: Boolean = true, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt index ed6607d3e..1d9d6a9a4 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt @@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project import com.intellij.util.io.HttpRequests import com.intellij.util.io.HttpRequests.HttpStatusException import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.data.ChatMessage import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -35,22 +36,29 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { errorMonitor: ErrorMonitor, ): SendResult { // Prepare the chat - val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), chatHistory) + val messages = chatHistory.map { + val role = when (it.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "assistant" + } + OpenAIChatMessage(role, it.content) + } + + val llmRequestBody = OpenAIRequestBody(LlmSettingsArguments(project).getModel(), messages) var sendResult = SendResult.OK try { - httpRequest.connect { - it.write(GsonBuilder().create().toJson(llmRequestBody)) + httpRequest.connect { request -> + // send request to OpenAI API + request.write(GsonBuilder().create().toJson(llmRequestBody)) + + val connection = request.connection as HttpURLConnection // check response - when (val responseCode = (it.connection as HttpURLConnection).responseCode) { + when (val responseCode = connection.responseCode) { HttpURLConnection.HTTP_OK -> { - assembleLlmResponse( - httpRequest = it, - indicator, - testsAssembler, - ) + assembleLlmResponse(request, testsAssembler, indicator) } HttpURLConnection.HTTP_INTERNAL_ERROR -> { @@ -105,13 +113,12 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { */ private fun assembleLlmResponse( httpRequest: HttpRequests.Request, - indicator: CustomProgressIndicator, testsAssembler: TestsAssembler, + indicator: CustomProgressIndicator, ) { while (true) { if (ToolUtils.isProcessCanceled(indicator)) return - Thread.sleep(50L) var text = httpRequest.reader.readLine() if (text.isEmpty()) continue diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index b1473b0c9..10aded741 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -3,12 +3,14 @@ package org.jetbrains.research.testspark.tools.llm.test import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.generation.llm.getClassWithTestCaseName +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM -import org.jetbrains.research.testspark.helpers.JavaClassBuilderHelper +import org.jetbrains.research.testspark.tools.TestClassCodeGeneratorFactory class JUnitTestSuitePresenter( private val project: Project, private val generatedTestsData: TestGenerationData, + private val language: SupportedLanguage, ) { /** * Returns a string representation of this object. @@ -34,12 +36,12 @@ class JUnitTestSuitePresenter( // Add each test testCases.forEach { testCase -> testBody += "$testCase\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -57,12 +59,12 @@ class JUnitTestSuitePresenter( testCaseIndex: Int, ): String = testSuite.run { - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, getClassWithTestCaseName(testCases[testCaseIndex].name), testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -81,12 +83,12 @@ class JUnitTestSuitePresenter( // Add each test (exclude expected exception) testCases.forEach { testCase -> testBody += "${testCase.toStringWithoutExpectedException()}\n" } - JavaClassBuilderHelper.generateCode( + TestClassCodeGeneratorFactory.create(language).generateCode( project, testFileName, testBody, imports, - packageString, + packageName, runWith, otherInfo, generatedTestsData, @@ -105,8 +107,8 @@ class JUnitTestSuitePresenter( fun getPrintablePackageString(testSuite: TestSuiteGeneratedByLLM): String { return testSuite.run { when { - packageString.isEmpty() || packageString.isBlank() -> "" - else -> packageString + packageName.isEmpty() || packageName.isBlank() -> "" + else -> packageName } } } diff --git a/src/main/resources/properties/llm/LLMDefaults.properties b/src/main/resources/properties/llm/LLMDefaults.properties index 156f15cbd..1eddae6e2 100644 --- a/src/main/resources/properties/llm/LLMDefaults.properties +++ b/src/main/resources/properties/llm/LLMDefaults.properties @@ -4,6 +4,10 @@ openAIModel= grazieName=AI Assistant JetBrains grazieToken= grazieModel= +huggingFaceName=HuggingFace +huggingFaceToken= +huggingFaceModel= +huggingFaceInitialSystemPrompt=You are a helpful and honest code and programming assistant. Please, respond concisely and truthfully. maxLLMRequest=3 maxInputParamsDepth=2 maxPolyDepth=2 diff --git a/src/main/resources/properties/llm/LLMMessages.properties b/src/main/resources/properties/llm/LLMMessages.properties index db087d5c1..3502840ab 100644 --- a/src/main/resources/properties/llm/LLMMessages.properties +++ b/src/main/resources/properties/llm/LLMMessages.properties @@ -14,4 +14,5 @@ grazieError=Grazie test generation feature is not available in this build. removeTemplateMessage=Choose another default template to remove this one. removeTemplateTitle=Can't Be Removed defaultPromptIsNotValidMessage=Default prompt is not valid. Fix it, please. -defaultPromptIsNotValidTitle=Incorrect Prompt State \ No newline at end of file +defaultPromptIsNotValidTitle=Incorrect Prompt State +hfServerError=The selected model may need an HF PRO subscription to use! \ No newline at end of file From 53829defa6cf76ab4ad7359cb230481eeeddd0dc Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:28:06 +0200 Subject: [PATCH 13/19] fixing bugs --- .../testspark/display/TestSuiteView.kt | 40 ++++++ .../display/java/JavaTestSuiteView.kt | 136 ++++++++++++++++++ .../display/kotlin/KotlinTestSuiteView.kt | 136 ++++++++++++++++++ .../java/JavaTestCaseDisplayService.kt | 2 +- .../kotlin/KotlinTestCaseDisplayService.kt | 2 +- 5 files changed, 314 insertions(+), 2 deletions(-) create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/display/TestSuiteView.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt create mode 100644 src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TestSuiteView.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TestSuiteView.kt new file mode 100644 index 000000000..d131d0785 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TestSuiteView.kt @@ -0,0 +1,40 @@ +package org.jetbrains.research.testspark.display + +import javax.swing.JPanel + +interface TestSuiteView { + /** + * Updates the labels. + */ + fun updateTopLabels() + + /** + * Toggles check boxes so that they are either all selected or all not selected, + * depending on the provided parameter. + * + * @param selected whether the checkboxes have to be selected or not + */ + fun toggleAllCheckboxes(selected: Boolean) + + /** + * Removes all test cases from the cache and tool window UI. + */ + fun removeAllTestCases() + + /** + * Executes all test cases. + * + * This method presents a caution message to the user and asks for confirmation before executing the test cases. + * If the user confirms, it iterates through each test case panel factory and runs the corresponding test. + */ + fun runAllTestCases() + + /** + * Sets the array of TestCasePanelFactory objects. + * + * @param testCasePanelFactories The ArrayList containing the TestCasePanelFactory objects to be set. + */ + fun setTestCasePanelFactoriesArray(testCasePanelFactories: ArrayList) + fun getPanel(): JPanel + fun clear() +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt new file mode 100644 index 000000000..c17844868 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt @@ -0,0 +1,136 @@ +package org.jetbrains.research.testspark.display + +import com.intellij.openapi.progress.ProgressIndicator +import com.intellij.openapi.progress.ProgressManager +import com.intellij.openapi.progress.Task +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.display.custom.IJProgressIndicator +import org.jetbrains.research.testspark.display.strategies.TopButtonsPanelStrategy +import java.awt.Dimension +import java.util.LinkedList +import java.util.Queue +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JLabel +import javax.swing.JOptionPane +import javax.swing.JPanel + +class JavaTestSuiteView(private val project: Project) : TestSuiteView { + private val testCasePanelFactories = arrayListOf() + + private var runAllButton: JButton = createRunAllTestButton() + private var selectAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) + private var unselectAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.unselectAll, PluginLabelsBundle.get("unselectAllTip")) + private var removeAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.removeAll, PluginLabelsBundle.get("removeAllTip")) + + private var testsSelectedText: String = "${PluginLabelsBundle.get("testsSelected")}: %d/%d" + private var testsSelectedLabel: JLabel = JLabel(testsSelectedText) + + private val testsPassedText: String = "${PluginLabelsBundle.get("testsPassed")}: %d/%d" + private var testsPassedLabel: JLabel = JLabel(testsPassedText) + + override fun updateTopLabels() { + TopButtonsPanelStrategy.updateTopJavaLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, + testsPassedText, + runAllButton, + ) + } + + override fun toggleAllCheckboxes(selected: Boolean) { + TopButtonsPanelStrategy.toggleAllJavaCheckboxes(selected, project) + } + + override fun removeAllTestCases() { + TopButtonsPanelStrategy.removeAllJavaTestCases(project) + } + + override fun runAllTestCases() { + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("runCautionMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.OK_CANCEL_OPTION, + JOptionPane.WARNING_MESSAGE, + ) + + if (choice == JOptionPane.CANCEL_OPTION) return + + runAllButton.isEnabled = false + + // add each test generation task to queue + val tasks: Queue<(CustomProgressIndicator) -> Unit> = LinkedList() + + for (testCasePanelFactory in testCasePanelFactories) { + testCasePanelFactory.addTask(tasks) + } + // run tasks one after each other + executeTasks(tasks) + } + + private fun executeTasks(tasks: Queue<(CustomProgressIndicator) -> Unit>) { + val nextTask = tasks.poll() + + nextTask?.let { task -> + ProgressManager.getInstance().run(object : Task.Backgroundable(project, "Test execution") { + override fun run(indicator: ProgressIndicator) { + task(IJProgressIndicator(indicator)) + } + + override fun onFinished() { + super.onFinished() + executeTasks(tasks) + } + }) + } + } + + override fun setTestCasePanelFactoriesArray(testCasePanelFactories: ArrayList) { + this.testCasePanelFactories.addAll(testCasePanelFactories) + } + + override fun getPanel(): JPanel { + val panel = JPanel() + panel.layout = BoxLayout(panel, BoxLayout.X_AXIS) + panel.preferredSize = Dimension(0, 30) + panel.add(Box.createRigidArea(Dimension(10, 0))) + panel.add(testsPassedLabel) + panel.add(Box.createRigidArea(Dimension(10, 0))) + panel.add(testsSelectedLabel) + panel.add(Box.createHorizontalGlue()) + panel.add(runAllButton) + panel.add(selectAllButton) + panel.add(unselectAllButton) + panel.add(removeAllButton) + + selectAllButton.addActionListener { toggleAllCheckboxes(true) } + unselectAllButton.addActionListener { toggleAllCheckboxes(false) } + removeAllButton.addActionListener { removeAllTestCases() } + runAllButton.addActionListener { runAllTestCases() } + + return panel + } + + override fun clear() { + testCasePanelFactories.clear() + } + + private fun createRunAllTestButton(): JButton { + val runTestButton = JButton(PluginLabelsBundle.get("runAll"), TestSparkIcons.runTest) + runTestButton.isOpaque = false + runTestButton.isContentAreaFilled = false + runTestButton.isBorderPainted = true + return runTestButton + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt new file mode 100644 index 000000000..68f624ad7 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt @@ -0,0 +1,136 @@ +package org.jetbrains.research.testspark.display + +import com.intellij.openapi.progress.ProgressIndicator +import com.intellij.openapi.progress.ProgressManager +import com.intellij.openapi.progress.Task +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.bundles.plugin.PluginLabelsBundle +import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.display.custom.IJProgressIndicator +import org.jetbrains.research.testspark.display.strategies.TopButtonsPanelStrategy +import java.awt.Dimension +import java.util.LinkedList +import java.util.Queue +import javax.swing.Box +import javax.swing.BoxLayout +import javax.swing.JButton +import javax.swing.JLabel +import javax.swing.JOptionPane +import javax.swing.JPanel + +class KotlinTestSuiteView(private val project: Project) : TestSuiteView { + private val testCasePanelFactories = arrayListOf() + + private var runAllButton: JButton = createRunAllTestButton() + private var selectAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.selectAll, PluginLabelsBundle.get("selectAllTip")) + private var unselectAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.unselectAll, PluginLabelsBundle.get("unselectAllTip")) + private var removeAllButton: JButton = + IconButtonCreator.getButton(TestSparkIcons.removeAll, PluginLabelsBundle.get("removeAllTip")) + + private var testsSelectedText: String = "${PluginLabelsBundle.get("testsSelected")}: %d/%d" + private var testsSelectedLabel: JLabel = JLabel(testsSelectedText) + + private val testsPassedText: String = "${PluginLabelsBundle.get("testsPassed")}: %d/%d" + private var testsPassedLabel: JLabel = JLabel(testsPassedText) + + override fun updateTopLabels() { + TopButtonsPanelStrategy.updateTopKotlinLabels( + testCasePanelFactories, + testsSelectedLabel, + testsSelectedText, + project, + testsPassedLabel, + testsPassedText, + runAllButton, + ) + } + + override fun toggleAllCheckboxes(selected: Boolean) { + TopButtonsPanelStrategy.toggleAllKotlinCheckboxes(selected, project) + } + + override fun removeAllTestCases() { + TopButtonsPanelStrategy.removeAllKotlinTestCases(project) + } + + override fun runAllTestCases() { + val choice = JOptionPane.showConfirmDialog( + null, + PluginMessagesBundle.get("runCautionMessage"), + PluginMessagesBundle.get("confirmationTitle"), + JOptionPane.OK_CANCEL_OPTION, + JOptionPane.WARNING_MESSAGE, + ) + + if (choice == JOptionPane.CANCEL_OPTION) return + + runAllButton.isEnabled = false + + // add each test generation task to queue + val tasks: Queue<(CustomProgressIndicator) -> Unit> = LinkedList() + + for (testCasePanelFactory in testCasePanelFactories) { + testCasePanelFactory.addTask(tasks) + } + // run tasks one after each other + executeTasks(tasks) + } + + private fun executeTasks(tasks: Queue<(CustomProgressIndicator) -> Unit>) { + val nextTask = tasks.poll() + + nextTask?.let { task -> + ProgressManager.getInstance().run(object : Task.Backgroundable(project, "Test execution") { + override fun run(indicator: ProgressIndicator) { + task(IJProgressIndicator(indicator)) + } + + override fun onFinished() { + super.onFinished() + executeTasks(tasks) + } + }) + } + } + + override fun setTestCasePanelFactoriesArray(testCasePanelFactories: ArrayList) { + this.testCasePanelFactories.addAll(testCasePanelFactories) + } + + override fun getPanel(): JPanel { + val panel = JPanel() + panel.layout = BoxLayout(panel, BoxLayout.X_AXIS) + panel.preferredSize = Dimension(0, 30) + panel.add(Box.createRigidArea(Dimension(10, 0))) + panel.add(testsPassedLabel) + panel.add(Box.createRigidArea(Dimension(10, 0))) + panel.add(testsSelectedLabel) + panel.add(Box.createHorizontalGlue()) + panel.add(runAllButton) + panel.add(selectAllButton) + panel.add(unselectAllButton) + panel.add(removeAllButton) + + selectAllButton.addActionListener { toggleAllCheckboxes(true) } + unselectAllButton.addActionListener { toggleAllCheckboxes(false) } + removeAllButton.addActionListener { removeAllTestCases() } + runAllButton.addActionListener { runAllTestCases() } + + return panel + } + + override fun clear() { + testCasePanelFactories.clear() + } + + private fun createRunAllTestButton(): JButton { + val runTestButton = JButton(PluginLabelsBundle.get("runAll"), TestSparkIcons.runTest) + runTestButton.isOpaque = false + runTestButton.isContentAreaFilled = false + runTestButton.isBorderPainted = true + return runTestButton + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt index 0dbc5009c..228ec3d7c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/java/JavaTestCaseDisplayService.kt @@ -67,7 +67,7 @@ class JavaTestCaseDisplayService(private val project: Project) : TestCaseDisplay private var mainPanel: JPanel = JPanel() - private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Java) + private val topButtonsPanelFactory = TopButtonsPanelFactory(project).create(SupportedLanguage.Java) private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt index a77edd16d..ba16bdd47 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/services/kotlin/KotlinTestCaseDisplayService.kt @@ -69,7 +69,7 @@ class KotlinTestCaseDisplayService(private val project: Project) : TestCaseDispl private var mainPanel: JPanel = JPanel() - private val topButtonsPanelFactory = TopButtonsPanelFactory(project, SupportedLanguage.Kotlin) + private val topButtonsPanelFactory = TopButtonsPanelFactory(project).create(SupportedLanguage.Kotlin) private var applyButton: JButton = JButton(PluginLabelsBundle.get("applyButton")) From a6c9cbabc618e0f3669b6a984f88c6e5cf8c1621 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Mon, 29 Jul 2024 21:29:46 +0200 Subject: [PATCH 14/19] top buttons panel factory --- .../research/testspark/display/TopButtonsPanelFactory.kt | 2 +- .../research/testspark/display/java/JavaTestSuiteView.kt | 2 +- .../research/testspark/display/kotlin/KotlinTestSuiteView.kt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt index 1a5938be1..c3be1c2ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/TopButtonsPanelFactory.kt @@ -10,4 +10,4 @@ class TopButtonsPanelFactory(private val project: Project) { SupportedLanguage.Kotlin -> KotlinTestSuiteView(project) } } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt index c17844868..b76b92070 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/java/JavaTestSuiteView.kt @@ -133,4 +133,4 @@ class JavaTestSuiteView(private val project: Project) : TestSuiteView { runTestButton.isBorderPainted = true return runTestButton } -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt index 68f624ad7..0d85bef67 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/kotlin/KotlinTestSuiteView.kt @@ -133,4 +133,4 @@ class KotlinTestSuiteView(private val project: Project) : TestSuiteView { runTestButton.isBorderPainted = true return runTestButton } -} \ No newline at end of file +} From e80adcb9399a0eb9f73ff30bc25c99ffb8a4310c Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Tue, 30 Jul 2024 20:27:32 +0200 Subject: [PATCH 15/19] first round of documenting the work --- CONTRIBUTING.md | 107 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d8d8cecd0..40b65d5c7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,15 @@ # TestSpark +## Table of contents +- [Description](#description) +- [Build project](#build-project) +- [Run IDE for UI tests](#run-IDE-for-ui-tests) +- [Plugin Configuration](#plugin-configuration-file) +- [Language Support Documentation](#language-support-documentation) +- [Classes](#classes) +- [Tests](#tests) + + ## Description In this document you can find the overall structure of TestSpark plugin. The classes are listed and their purpose is described. This section is intended for developers and contributors to TestSpark plugin. @@ -23,6 +33,103 @@ to include test generation using Grazie in the runIdeForUiTests process, you nee `` is generated by Space, which has access to Automatically generating unit tests maven packages. +--- + +## Language Support Documentation + +# CONTRIBUTORS.md + +## Language Support Documentation + +The TestSpark plugin supports automatic test generation for various Java and Kotlin programming languages and aims to support even more programming languages in the future. + +This document provides an overview of the existing implementation for Java and Kotlin support and guidelines for adding support for additional programming languages. + +## Key Components + +### 1. PSI Parsers + +The first step is to enable the collection of the appropriate information for the code under test. This part is responsible for working with the PSI (Program Structure Interface) generated by IntelliJ IDEA. It helps parse the part where the cursor is located, provides a choice of the code elements that are available for testing at its position, and then finds all the needed dependencies to make the prompt complete with all the necessary knowledge about the code under test. + +This part is the most granular but complex at the same time. + +The main reason for this is to include dependencies only for the languages we need. This avoids errors if the user does not have some languages that our plugin supports. For example, if we work with a Python project, we don't want to depend on Kotlin because it will cause an error if Kotlin isn't present. Additionally, we want to incrementally add dependencies on other languages for faster startup. For example, we do not want to fetch the dependency on Java when we work with TypeScript. Other benefits include better organization, easier maintenance, and clearer separation of concerns. As a side-bonus, the addition of new languages will be easier. + +**Module Dependencies:** + +- **langwrappers**: This is a foundational module for language extensions. +- ****: Depends on the `langwrappers` module to implement the ``-specific `PsiHelper` and `PsiHelperProvider`. +- **src/**: Depends on `langwrappers` because we want to use `PsiHelper` and other interfaces regardless of the current language. Depends on ``, to make `plugin.xml` aware of the implementations of the Extension Point. + +**Plugin Dependencies:** + +- The main `plugin.xml` file declares the `psiHelperProvider` extension point using the `com.intellij.lang.LanguageExtensionPoint` class. +- The language-specific modules extend this extension point to register their implementations. +- When the project is opened, we load the EPs needed to work with the current project. Then, using the `PsiHelperProvider` interface, we can get the appropriate `PsiHelper` class per file. + +**Implementation Details:** + +- **Common Module (`langwrappers`)**: + - Contains the `PsiHelper` interface, which provides the necessary methods to interact with `psiFile`. + - The `PsiHelperProvider` class includes a companion object to fetch the appropriate `PsiHelper` implementation based on the file's language. + +- ** Module**: + - Implements the `PsiHelper` and `PsiHelperProvider` classes, which provide -specific logic. + - Declares the extension point in `testspark-.xml`. + +To add new languages, create a separate module for this language and register its implementation as an extension of the `psiHelperProvider` EP. Then follow the template provided above. + +### 2. Prompt Generation + +When we know how to parse the code, we need to construct the prompt. + +For each language, adjust the prompt that goes to the LLM. Ensure that the language, framework platform, and mocking framework are defined correctly in: + +```kotlin +data class PromptConfiguration( + val desiredLanguage: String, + val desiredTestingPlatform: String, + val desiredMockingFramework: String, +) +``` + +Additionally, check that all the dependencies (collected by `PsiHelper` for the current strategy) are passed properly. `PromptGenerator` and `PromptBuilder` are responsible for this job. + +### 3. Parsing LLM Response + +When the LLM response to our prompt is received, we have to parse it. + +We want to retrieve test functions from the response, collect them separately (and all together) in the tmp folder, and check for compilation. + +The current structure of this part is located in: +- `kotlin/org/jetbrains/research/testspark/core/test` +- `kotlin/org/jetbrains/research/testspark/tools` + +It can be more easily understood with the following diagram: +![](https://private-user-images.githubusercontent.com/70476032/349256986-dc7e1ff9-a9a5-4bd2-a51f-ecbfabeb6cba.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjIzNTEyOTAsIm5iZiI6MTcyMjM1MDk5MCwicGF0aCI6Ii83MDQ3NjAzMi8zNDkyNTY5ODYtZGM3ZTFmZjktYTlhNS00YmQyLWE1MWYtZWNiZmFiZWI2Y2JhLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA3MzAlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNzMwVDE0NDk1MFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWJjMDg3MWM2ZDA4MDJlZGUwNzliMzNkNzA3YWI4YTcwM2RmYTFjMmE1MGM4MjM5NjJiOGI2ZjgxNTE2OTU2YjQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.8OfRa1wJhDfFq3QT6h5yIjBh1VqB9UrrQfZGp0_SLDo) + +- `TestsAssembler`: Assembler class for generating and organizing test cases from the LLM response. +- `TestSuiteParser`: Extracts test cases from raw text and generates a test suite. +- `TestBodyPrinter`: Generates the body of a test function as a string. + +### 4. Compilation + +Before showing the code to the user, it should be checked for compilation. + +- `TestCompiler`: Compiles a list of test cases and returns the compilation result. + +### 5. UI Representation + +Once we parse the code generated by the LLM and confirm that the code is compilable, it should be presented in the UI. + +There are special interfaces that help to work with already parsed test classes and are specified for each language: + +- `TestCaseDisplayService`: Service responsible for the UI representation. +- `TestSuiteView`: Interface specific for working with buttons. +- `TestClassCodeAnalyzer`: Interface for retrieving information from test class code. +- `TestClassCodeGenerator`: Interface for generating and formatting test class code. +--- + ## Plugin Configuration File The plugin configuration file is `plugin.xml` which can be found in `src/main/resources/META-INF` directory. All declarations (such as actions, services, listeners) are present in this file. From dad01cc6969de0dc00c9dde25686992c13a41e9f Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Tue, 30 Jul 2024 21:05:41 +0200 Subject: [PATCH 16/19] added one more section --- CONTRIBUTING.md | 82 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 40b65d5c7..2e58a4f1a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -37,71 +37,91 @@ to include test generation using Grazie in the runIdeForUiTests process, you nee ## Language Support Documentation -# CONTRIBUTORS.md +The TestSpark plugin supports automatic test generation for various programming languages (currently Java and Kotlin) +and aims to support even more programming languages in the future. -## Language Support Documentation - -The TestSpark plugin supports automatic test generation for various Java and Kotlin programming languages and aims to support even more programming languages in the future. - -This document provides an overview of the existing implementation for Java and Kotlin support and guidelines for adding support for additional programming languages. +This document provides an overview of the existing implementation of Kotlin and Java support and guidelines for adding +more programming languages. ## Key Components ### 1. PSI Parsers -The first step is to enable the collection of the appropriate information for the code under test. This part is responsible for working with the PSI (Program Structure Interface) generated by IntelliJ IDEA. It helps parse the part where the cursor is located, provides a choice of the code elements that are available for testing at its position, and then finds all the needed dependencies to make the prompt complete with all the necessary knowledge about the code under test. +The first step is to enable the collection of the appropriate information for the code under test. This part is +responsible for working with the PSI (Program Structure Interface) generated by IntelliJ IDEA. It helps parse the part +where the cursor is located, provides a choice of the code elements that are available for testing at cursor's position. +Then find all the needed dependencies to make the prompt complete with all the necessary knowledge about the code under +test. This part is the most granular but complex at the same time. -The main reason for this is to include dependencies only for the languages we need. This avoids errors if the user does not have some languages that our plugin supports. For example, if we work with a Python project, we don't want to depend on Kotlin because it will cause an error if Kotlin isn't present. Additionally, we want to incrementally add dependencies on other languages for faster startup. For example, we do not want to fetch the dependency on Java when we work with TypeScript. Other benefits include better organization, easier maintenance, and clearer separation of concerns. As a side-bonus, the addition of new languages will be easier. +The main reason for this is to include dependencies only for the languages we need. This avoids errors if the user does +not have some languages that our plugin supports. _For example, if we work with a Python project, we don't want to depend +on Kotlin because it will cause an error if Kotlin isn't present._ + +Additionally, we want to incrementally add dependencies on other languages for faster startup. +_For example, we do not want to fetch the dependency on Java when we work with TypeScript._ +Other benefits include better organization, easier maintenance, and clearer separation of +concerns. As a side-bonus, the addition of new languages will be easier. **Module Dependencies:** - **langwrappers**: This is a foundational module for language extensions. -- ****: Depends on the `langwrappers` module to implement the ``-specific `PsiHelper` and `PsiHelperProvider`. -- **src/**: Depends on `langwrappers` because we want to use `PsiHelper` and other interfaces regardless of the current language. Depends on ``, to make `plugin.xml` aware of the implementations of the Extension Point. +- ****: Depends on the `langwrappers` module to implement the ``-specific `PsiHelper` + and `PsiHelperProvider`. +- **src/**: Depends on `langwrappers` because we want to use `PsiHelper` and other interfaces regardless of the current + language. Depends on ``, to make `plugin.xml` aware of the implementations of the Extension Point. **Plugin Dependencies:** -- The main `plugin.xml` file declares the `psiHelperProvider` extension point using the `com.intellij.lang.LanguageExtensionPoint` class. +- The main `plugin.xml` file declares the `psiHelperProvider` extension point using + the `com.intellij.lang.LanguageExtensionPoint` class. - The language-specific modules extend this extension point to register their implementations. -- When the project is opened, we load the EPs needed to work with the current project. Then, using the `PsiHelperProvider` interface, we can get the appropriate `PsiHelper` class per file. +- When the project is opened, we load the EPs needed to work with the current project. Then, using + the `PsiHelperProvider` interface, we can get the appropriate `PsiHelper` class per file. **Implementation Details:** - **Common Module (`langwrappers`)**: - Contains the `PsiHelper` interface, which provides the necessary methods to interact with `psiFile`. - - The `PsiHelperProvider` class includes a companion object to fetch the appropriate `PsiHelper` implementation based on the file's language. + - The `PsiHelperProvider` class includes a companion object to fetch the appropriate `PsiHelper` implementation + based on the file's language. - ** Module**: - - Implements the `PsiHelper` and `PsiHelperProvider` classes, which provide -specific logic. + - Implements the `PsiHelper` and `PsiHelperProvider` classes, which provide -specific + logic. - Declares the extension point in `testspark-.xml`. -To add new languages, create a separate module for this language and register its implementation as an extension of the `psiHelperProvider` EP. Then follow the template provided above. +To add new languages, create a separate module for this language and register its implementation as an extension of +the `psiHelperProvider` EP. Then follow the template provided above. ### 2. Prompt Generation When we know how to parse the code, we need to construct the prompt. -For each language, adjust the prompt that goes to the LLM. Ensure that the language, framework platform, and mocking framework are defined correctly in: +For each language, adjust the prompt that goes to the LLM. Ensure that the language, framework platform, and mocking +framework are defined correctly in: ```kotlin -data class PromptConfiguration( - val desiredLanguage: String, - val desiredTestingPlatform: String, - val desiredMockingFramework: String, +data class PromptConfiguration( + val desiredLanguage: String, + val desiredTestingPlatform: String, + val desiredMockingFramework: String, ) ``` -Additionally, check that all the dependencies (collected by `PsiHelper` for the current strategy) are passed properly. `PromptGenerator` and `PromptBuilder` are responsible for this job. +Additionally, check that all the dependencies (collected by `PsiHelper` for the current strategy) are passed +properly. `PromptGenerator` and `PromptBuilder` are responsible for this job. ### 3. Parsing LLM Response When the LLM response to our prompt is received, we have to parse it. -We want to retrieve test functions from the response, collect them separately (and all together) in the tmp folder, and check for compilation. +We want to retrieve test case, all the test functions and additional information like imports or supporting functions +from the response. The current structure of this part is located in: + - `kotlin/org/jetbrains/research/testspark/core/test` - `kotlin/org/jetbrains/research/testspark/tools` @@ -118,16 +138,26 @@ Before showing the code to the user, it should be checked for compilation. - `TestCompiler`: Compiles a list of test cases and returns the compilation result. -### 5. UI Representation +Here one should specify the appropriate compilation strategy for each language. With all the dependencies and build paths. -Once we parse the code generated by the LLM and confirm that the code is compilable, it should be presented in the UI. +### 5. UI Representation -There are special interfaces that help to work with already parsed test classes and are specified for each language: +Once the code generated by the LLM is checked for the compilation, it should be presented in the UI. -- `TestCaseDisplayService`: Service responsible for the UI representation. +- `TestCaseDisplayService`: Service responsible for the representation of all the UI components. - `TestSuiteView`: Interface specific for working with buttons. - `TestClassCodeAnalyzer`: Interface for retrieving information from test class code. - `TestClassCodeGenerator`: Interface for generating and formatting test class code. + +### 6. Running and saving tests + +We should be able to run all the tests in the UI and then save them to the desired folder. + +- `TestPersistentStorage`: Interface representing a contract for saving generated tests to a specified file system location. + +For Kotlin and Java, the `TestProcessor` implementation also allows saving the JaCoCo report to see the code coverage of +the test that will be saved. + --- ## Plugin Configuration File From 62a32def00514ee13df4cca27585386d93d48542 Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Tue, 30 Jul 2024 21:18:37 +0200 Subject: [PATCH 17/19] added the RuntimeException --- .../testspark/core/test/kotlin/KotlinTestCompiler.kt | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt index 8d61ce68e..63495e79b 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt @@ -12,6 +12,7 @@ class KotlinTestCompiler(libPaths: List, junitLibPaths: List) : override fun compileCode(path: String, projectBuildPath: String): Pair { log.info { "[KotlinTestCompiler] Compiling ${path.substringAfterLast('/')}" } + // TODO find the kotlinc if it is not in PATH val classPaths = "\"${getClassPaths(projectBuildPath)}\"" // Compile file val errorMsg = CommandLineRunner.run( @@ -23,7 +24,13 @@ class KotlinTestCompiler(libPaths: List, junitLibPaths: List) : ), ) - log.info { "Error message: '$errorMsg'" } + if (errorMsg.isNotEmpty()) { + log.info { "Error message: '$errorMsg'" } + if (errorMsg.contains("kotlinc: command not found'")) { + throw RuntimeException(errorMsg) + } + } + // No need to save the .class file for kotlin, so checking the error message is enough return Pair(errorMsg.isBlank(), errorMsg) From c9552c7940ce59ce54f853a60fc3597b547d0a8c Mon Sep 17 00:00:00 2001 From: Braun Ekaterina Date: Tue, 30 Jul 2024 21:24:37 +0200 Subject: [PATCH 18/19] klint --- .../research/testspark/core/test/kotlin/KotlinTestCompiler.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt index 63495e79b..2c898a4b2 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinTestCompiler.kt @@ -31,7 +31,6 @@ class KotlinTestCompiler(libPaths: List, junitLibPaths: List) : } } - // No need to save the .class file for kotlin, so checking the error message is enough return Pair(errorMsg.isBlank(), errorMsg) } From 81f963eca637d7d5251d3ae408818b65aa674b8c Mon Sep 17 00:00:00 2001 From: Braun Ekaterina <70476032+Frosendroska@users.noreply.github.com> Date: Tue, 30 Jul 2024 21:47:17 +0200 Subject: [PATCH 19/19] Update CONTRIBUTING.md --- CONTRIBUTING.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e58a4f1a..e71256d5b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -43,6 +43,12 @@ and aims to support even more programming languages in the future. This document provides an overview of the existing implementation of Kotlin and Java support and guidelines for adding more programming languages. +>How can I add support for a new programming language? +In brief, you need to extend all the necessary interfaces with implementations specific to the new language. +Below, you will find a detailed guide divided into six key components of the entire pipeline with the most +important interfaces addressing this goal. + + ## Key Components ### 1. PSI Parsers