diff --git a/README.md b/README.md index 3a49e1278..2498b617a 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,7 @@ poetry run cover-agent \ --coverage-type "cobertura" \ --desired-coverage 70 \ --max-iterations 1 \ - --openai-model "gpt-4o" \ - --additional-instructions "Since I am using a test class each line of code (including the first line), In your response, will need to be prepended with 4 whitespaces. This is extremely important to check to make sure every line returned contains that 4 whitespace indent otherwise my code will not run." + --openai-model "gpt-4o" ``` Note: If you are using Poetry then use the `poetry run python -m cover-agent` command instead of the `cover-agent` run command. diff --git a/cover_agent/FilePreprocessor.py b/cover_agent/FilePreprocessor.py new file mode 100644 index 000000000..4254982fb --- /dev/null +++ b/cover_agent/FilePreprocessor.py @@ -0,0 +1,49 @@ +import ast +import textwrap + + +class FilePreprocessor: + def __init__(self, path_to_file): + self.path_to_file = path_to_file + + # List of rules/action key pair. + # Add your new rule and how to process the text (function) here + self.rules = [(self._is_python_file, self._process_if_python)] + + def process_file(self, text: str) -> str: + """ + Process the text based on the internal rules. + """ + for condition, action in self.rules: + if condition(): + return action(text) + return text # Return the text unchanged if no rules apply + + def _is_python_file(self) -> bool: + """ + Rule to check if the file is a Python file. + """ + return self.path_to_file.endswith(".py") + + def _process_if_python(self, text: str) -> str: + """ + Action to process Python files by checking for class definitions and indenting if found. + """ + if self._contains_class_definition(): + return textwrap.indent(text, " ") + return text + + def _contains_class_definition(self) -> bool: + """ + Check if the file contains a Python class definition using the ast module. 
+ """ + try: + with open(self.path_to_file, "r") as file: + content = file.read() + parsed_ast = ast.parse(content) + for node in ast.walk(parsed_ast): + if isinstance(node, ast.ClassDef): + return True + except SyntaxError as e: + print(f"Syntax error when parsing the file: {e}") + return False diff --git a/cover_agent/PromptBuilder.py b/cover_agent/PromptBuilder.py index 6340fb568..5769870b1 100644 --- a/cover_agent/PromptBuilder.py +++ b/cover_agent/PromptBuilder.py @@ -20,6 +20,8 @@ {failed_test_runs} ``` """ + + class PromptBuilder: def __init__( @@ -60,9 +62,23 @@ def __init__( self.code_coverage_report = code_coverage_report # Conditionally fill in optional sections - self.included_files = ADDITIONAL_INCLUDES_TEXT.format(included_files=included_files) if included_files else included_files - self.additional_instructions = ADDITIONAL_INSTRUCTIONS_TEXT.format(additional_instructions=additional_instructions) if additional_instructions else additional_instructions - self.failed_test_runs = FAILED_TESTS_TEXT.format(failed_test_runs=failed_test_runs) if failed_test_runs else failed_test_runs + self.included_files = ( + ADDITIONAL_INCLUDES_TEXT.format(included_files=included_files) + if included_files + else included_files + ) + self.additional_instructions = ( + ADDITIONAL_INSTRUCTIONS_TEXT.format( + additional_instructions=additional_instructions + ) + if additional_instructions + else additional_instructions + ) + self.failed_test_runs = ( + FAILED_TESTS_TEXT.format(failed_test_runs=failed_test_runs) + if failed_test_runs + else failed_test_runs + ) def _read_file(self, file_path): """ diff --git a/cover_agent/UnitTestGenerator.py b/cover_agent/UnitTestGenerator.py index 89f811dbf..45f083d73 100644 --- a/cover_agent/UnitTestGenerator.py +++ b/cover_agent/UnitTestGenerator.py @@ -6,6 +6,7 @@ from cover_agent.CustomLogger import CustomLogger from cover_agent.PromptBuilder import PromptBuilder from cover_agent.AICaller import AICaller +from cover_agent.FilePreprocessor import FilePreprocessor class UnitTestGenerator: @@ -54,6 +55,8 @@ def __init__( # Get the logger instance from CustomLogger self.logger = CustomLogger.get_logger(__name__) + # States to maintain within this class + self.preprocessor = FilePreprocessor(self.test_file_path) self.failed_test_runs = [] # Run coverage and build the prompt @@ -140,7 +143,9 @@ def build_prompt(self): if not self.failed_test_runs: failed_test_runs_value = "" else: - failed_test_runs_value = json.dumps(self.failed_test_runs).replace("\\n", "\n") + failed_test_runs_value = json.dumps(self.failed_test_runs).replace( + "\\n", "\n" + ) # Call PromptBuilder to build the prompt prompt = PromptBuilder( @@ -172,7 +177,7 @@ def generate_tests(self, LLM_model="gpt-4o", max_tokens=4096, dry_run=False): # We want to remove them and split up the tests into a list of tests response = ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens) - # Split the response into a list of tests and strip off the trailing whitespaces + # Split the response into a list of tests and strip off the trailing whitespaces # (as we sometimes anticipate indentations in the returned code from the LLM) tests = response.split("```") return [test.rstrip() for test in tests if test.rstrip()] @@ -191,13 +196,16 @@ def validate_test(self, generated_test: str): dict: A dictionary containing the test result status, reason for failure (if any), stdout, stderr, exit code, and the test itself. 
""" + # Step 0: Run the test through the preprocessor rule set + processed_test = self.preprocessor.process_file(generated_test) + # Step 1: Append the generated test to the test file and save the original content with open(self.test_file_path, "r+") as test_file: original_content = test_file.read() # Store original content test_file.write( "\n" + ("\n" if not original_content.endswith("\n") else "") - + generated_test + + processed_test + "\n" ) # Append the new test at the end @@ -223,7 +231,9 @@ def validate_test(self, generated_test: str): "stdout": stdout, "test": generated_test, } - self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list + self.failed_test_runs.append( + fail_details["test"] + ) # Append failure details to the list return fail_details # If test passed, check for coverage increase @@ -253,7 +263,9 @@ def validate_test(self, generated_test: str): "stdout": stdout, "test": generated_test, } - self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list + self.failed_test_runs.append( + fail_details["test"] + ) # Append failure details to the list return fail_details except Exception as e: # Handle errors gracefully @@ -269,7 +281,9 @@ def validate_test(self, generated_test: str): "stdout": stdout, "test": generated_test, } - self.failed_test_runs.append(fail_details["test"]) # Append failure details to the list + self.failed_test_runs.append( + fail_details["test"] + ) # Append failure details to the list return fail_details # If everything passed and coverage increased, update current coverage and log success diff --git a/cover_agent/main.py b/cover_agent/main.py index c72bbdf3f..5832e9213 100644 --- a/cover_agent/main.py +++ b/cover_agent/main.py @@ -36,7 +36,7 @@ def parse_args(): "--included-files", default=None, nargs="*", - help="List of files to include in the coverage. For example, \"--included-files library1.c library2.c.\" Default: %(default)s.", + help='List of files to include in the coverage. For example, "--included-files library1.c library2.c." Default: %(default)s.', ) parser.add_argument( "--coverage-type", @@ -134,9 +134,11 @@ def main(): and iteration_count < args.max_iterations ): # Provide coverage feedback to user - logger.info(f"Current Coverage: {round(test_gen.current_coverage * 100, 2)}%") + logger.info( + f"Current Coverage: {round(test_gen.current_coverage * 100, 2)}%" + ) logger.info(f"Desired Coverage: {test_gen.desired_coverage}%") - + # Generate tests by making a call to the LLM generated_tests = test_gen.generate_tests( LLM_model=args.openai_model, max_tokens=4096 @@ -154,7 +156,9 @@ def main(): iteration_count += 1 if iteration_count == args.max_iterations: - logger.info("Reached maximum iteration limit without achieving desired coverage.") + logger.info( + "Reached maximum iteration limit without achieving desired coverage." 
+ ) # Dump the test results to a report ReportGenerator.generate_report(test_results_list, "test_results.html") diff --git a/cover_agent/version.txt b/cover_agent/version.txt index db7a48047..28d007539 100644 --- a/cover_agent/version.txt +++ b/cover_agent/version.txt @@ -1 +1 @@ -0.1.31 +0.1.32 diff --git a/tests/test_FilePreprocessor.py b/tests/test_FilePreprocessor.py new file mode 100644 index 000000000..3b763739a --- /dev/null +++ b/tests/test_FilePreprocessor.py @@ -0,0 +1,53 @@ +import pytest +import tempfile +import textwrap +from cover_agent.FilePreprocessor import FilePreprocessor + + +class TestFilePreprocessor: + # Test for a C file + def test_c_file(self): + with tempfile.NamedTemporaryFile(delete=False, suffix=".c") as tmp: + preprocessor = FilePreprocessor(tmp.name) + input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt." + processed_text = preprocessor.process_file(input_text) + assert ( + processed_text == input_text + ), "C file processing should not alter the text." + + # Test for a Python file with only a function + def test_py_file_with_function_only(self): + with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp: + tmp.write(b"def function():\n pass\n") + tmp.close() + preprocessor = FilePreprocessor(tmp.name) + input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt." + processed_text = preprocessor.process_file(input_text) + assert ( + processed_text == input_text + ), "Python file without class should not alter the text." + + # Test for a Python file with a comment that looks like a class definition + def test_py_file_with_commented_class(self): + with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp: + tmp.write(b"# class myPythonFile:\n pass\n") + tmp.close() + preprocessor = FilePreprocessor(tmp.name) + input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt." + processed_text = preprocessor.process_file(input_text) + assert ( + processed_text == input_text + ), "Commented class definition should not trigger processing." + + # Test for a Python file with an actual class definition + def test_py_file_with_class(self): + with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as tmp: + tmp.write(b"class MyClass:\n def method(self):\n pass\n") + tmp.close() + preprocessor = FilePreprocessor(tmp.name) + input_text = "Lorem ipsum dolor sit amet,\nconsectetur adipiscing elit,\nsed do eiusmod tempor incididunt." + processed_text = preprocessor.process_file(input_text) + expected_output = textwrap.indent(input_text, " ") + assert ( + processed_text == expected_output + ), "Python file with class should indent the text." 
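
For reviewers who want to exercise the new preprocessor outside the test suite, here is a minimal usage sketch. The temp-file contents and the generated test body are invented for illustration; only `FilePreprocessor` and `process_file` come from this diff.

```python
import tempfile

from cover_agent.FilePreprocessor import FilePreprocessor

# Hypothetical test file whose tests live inside a unittest-style class,
# so generated tests must be indented to sit inside that class.
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
    tmp.write(
        "import unittest\n\n"
        "class TestExample(unittest.TestCase):\n"
        "    pass\n"
    )

preprocessor = FilePreprocessor(tmp.name)

generated_test = "def test_addition(self):\n    assert 1 + 1 == 2\n"
processed = preprocessor.process_file(generated_test)

# The class definition triggers the indent rule: every non-blank line
# gains a four-space prefix so it can be appended inside the test class.
assert all(
    line.startswith("    ") for line in processed.splitlines() if line.strip()
)
```

This is the behavior the removed `--additional-instructions` workaround in the README used to request from the LLM via the prompt; the rule-based preprocessor now applies the indent deterministically before `validate_test` appends the generated test to the test file.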
diff --git a/tests/test_PromptBuilder.py b/tests/test_PromptBuilder.py index b94df0ef5..2297df4da 100644 --- a/tests/test_PromptBuilder.py +++ b/tests/test_PromptBuilder.py @@ -2,6 +2,7 @@ from unittest.mock import patch, mock_open from cover_agent.PromptBuilder import PromptBuilder + class TestPromptBuilder: @pytest.fixture(autouse=True) def setup_method(self, monkeypatch): @@ -11,7 +12,10 @@ def setup_method(self, monkeypatch): def test_initialization_reads_file_contents(self): builder = PromptBuilder( - "cover_agent/prompt_template.md", "source_path", "test_path", "dummy content" + "cover_agent/prompt_template.md", + "source_path", + "test_path", + "dummy content", ) assert builder.prompt_template == "dummy content" assert builder.source_file == "dummy content" @@ -27,7 +31,7 @@ def test_build_prompt_replaces_placeholders_correctly(self): "coverage_report", "Included Files Content", "Additional Instructions Content", - "Failed Test Runs Content" + "Failed Test Runs Content", ) builder.prompt_template = "Template: {source_file}, Test: {test_file}, Coverage: {code_coverage_report}, Includes: {additional_includes_section}, Instructions: {additional_instructions_text}, Failed Tests: {failed_tests_section}" builder.source_file = "Source Content" @@ -48,7 +52,10 @@ def mock_open_raise(*args, **kwargs): monkeypatch.setattr("builtins.open", mock_open_raise) builder = PromptBuilder( - "cover_agent/prompt_template.md", "source_path", "test_path", "coverage_report" + "cover_agent/prompt_template.md", + "source_path", + "test_path", + "coverage_report", ) assert "Error reading cover_agent/prompt_template.md" in builder.prompt_template assert "Error reading source_path" in builder.source_file @@ -62,7 +69,7 @@ def test_empty_included_files_section_not_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - included_files="Included Files Content" + included_files="Included Files Content", ) # Directly read the real file content for the prompt template with open("cover_agent/prompt_template.md", "r") as f: @@ -83,7 +90,7 @@ def test_non_empty_included_files_section_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - included_files="Included Files Content" + included_files="Included Files Content", ) # Directly read the real file content for the prompt template @@ -106,7 +113,7 @@ def test_empty_additional_instructions_section_not_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - additional_instructions="" + additional_instructions="", ) # Directly read the real file content for the prompt template with open("cover_agent/prompt_template.md", "r") as f: @@ -126,7 +133,7 @@ def test_empty_failed_test_runs_section_not_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - failed_test_runs="" + failed_test_runs="", ) # Directly read the real file content for the prompt template with open("cover_agent/prompt_template.md", "r") as f: @@ -146,7 +153,7 @@ def test_non_empty_additional_instructions_section_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - additional_instructions="Additional Instructions Content" + additional_instructions="Additional Instructions Content", ) # Directly read the real file content for the 
prompt template with open("cover_agent/prompt_template.md", "r") as f: @@ -167,7 +174,7 @@ def test_non_empty_failed_test_runs_section_in_prompt(self, monkeypatch): source_file_path="source_path", test_file_path="test_path", code_coverage_report="coverage_report", - failed_test_runs="Failed Test Runs Content" + failed_test_runs="Failed Test Runs Content", ) # Directly read the real file content for the prompt template with open("cover_agent/prompt_template.md", "r") as f:
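
The `PromptBuilder` hunks above are formatting-only, but they make the optional-section idiom easier to read: each optional input is wrapped in its section template only when it is non-empty, so an empty input collapses to nothing in the rendered prompt. Below is a standalone sketch of that idiom, assuming an invented section body (only the `ADDITIONAL_INSTRUCTIONS_TEXT` name and the conditional-format pattern come from this diff).

```python
# Illustrative stand-in for the real template constant in PromptBuilder.py;
# the actual section text in the repo may differ.
ADDITIONAL_INSTRUCTIONS_TEXT = """
## Additional Instructions
{additional_instructions}
"""


def fill_optional_section(value: str) -> str:
    # Wrap the value in its section template only when it is truthy; an
    # empty string passes through unchanged, so the matching placeholder
    # in prompt_template.md renders as nothing.
    return (
        ADDITIONAL_INSTRUCTIONS_TEXT.format(additional_instructions=value)
        if value
        else value
    )


assert fill_optional_section("") == ""
assert "## Additional Instructions" in fill_optional_section("Use pytest fixtures.")
```

The tests in `tests/test_PromptBuilder.py` cover both branches: the `test_empty_*_not_in_prompt` cases check that empty inputs leave no section header in the prompt, and the `test_non_empty_*_in_prompt` cases check that provided content appears wrapped in its section.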