diff --git a/data_prep/introspector.py b/data_prep/introspector.py index 123154ac2b..bcbd859100 100755 --- a/data_prep/introspector.py +++ b/data_prep/introspector.py @@ -228,6 +228,15 @@ def query_introspector_cfg(project: str) -> dict: return _get_data(resp, 'project', {}) +def query_introspector_source_file_path(project: str, func_sig: str) -> str: + """Queries FuzzIntrospector API for file path of |func_sig|.""" + resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, { + 'project': project, + 'function_signature': func_sig + }) + return _get_data(resp, 'filepath', '') + + def query_introspector_function_source(project: str, func_sig: str) -> str: """Queries FuzzIntrospector API for source code of |func_sig|.""" resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, { diff --git a/data_prep/project_context/context_introspector.py b/data_prep/project_context/context_introspector.py index 0e094910ed..ccecdc630c 100644 --- a/data_prep/project_context/context_introspector.py +++ b/data_prep/project_context/context_introspector.py @@ -3,6 +3,7 @@ import logging import os +from difflib import SequenceMatcher from typing import Any from data_prep import introspector @@ -217,3 +218,37 @@ def get_type_def(self, type_name: str) -> str: type_names.append(new_type_name) return type_def + + def get_similar_header_file_paths(self, wrong_file: str) -> list[str]: + """Retrieves and finds 5 header file names closest to |wrong_name|.""" + header_list = introspector.query_introspector_header_files( + self._benchmark.project) + candidate_header_scores = { + header: + SequenceMatcher(lambda x: x in ['_', '/', '-', '.'], wrong_file, + header).ratio() for header in header_list + } + candidate_headers = sorted(candidate_header_scores, + key=lambda x: candidate_header_scores[x], + reverse=True) + return candidate_headers[:5] + + def get_target_function_file_path(self) -> str: + """Retrieves the header/source file of the function under test.""" + # Step 1: Find a header file that shares the same name as the source file. + # TODO: Make this more robust, e.g., when header file and base file do not + # share the same basename. + source_file = introspector.query_introspector_source_file_path( + self._benchmark.project, self._benchmark.function_signature) + source_file_base, _ = os.path.splitext(os.path.basename(source_file)) + header_list = introspector.query_introspector_header_files( + self._benchmark.project) + candidate_headers = [ + header for header in header_list + if os.path.basename(header).startswith(source_file_base) + ] + if candidate_headers: + return candidate_headers[0] + + # Step 2: Use the source file If it does not have a same-name-header. + return source_file diff --git a/llm_toolkit/code_fixer.py b/llm_toolkit/code_fixer.py index e939a823d1..781da4112f 100755 --- a/llm_toolkit/code_fixer.py +++ b/llm_toolkit/code_fixer.py @@ -29,6 +29,7 @@ ERROR_LINES = 20 NO_MEMBER_ERROR_REGEX = r"error: no member named '.*' in '([^':]*):?.*'" +FILE_NOT_FOUND_ERROR_REGEX = r"fatal error: '([^']*)' file not found" def parse_args(): @@ -410,8 +411,10 @@ def apply_llm_fix(ai_binary: str, builder = prompt_builder.DefaultTemplateBuilder(fixer_model) context = _collect_context(benchmark, errors) + instruction = _collect_instructions(benchmark, errors, + fuzz_target_source_code) prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code, - error_desc, errors, context) + error_desc, errors, context, instruction) prompt.save(prompt_path) fixer_model.generate_code(prompt, response_dir) @@ -441,6 +444,69 @@ def _collect_context_no_member(benchmark: benchmarklib.Benchmark, return ci.get_type_def(target_type) +def _collect_instructions(benchmark: benchmarklib.Benchmark, errors: list[str], + fuzz_target_source_code: str) -> str: + """Collects the useful instructions to fix the errors.""" + if not errors: + return '' + + instruction = '' + for error in errors: + instruction += _collect_instruction_file_not_found(benchmark, error, + fuzz_target_source_code) + return instruction + + +def _collect_instruction_file_not_found(benchmark: benchmarklib.Benchmark, + error: str, + fuzz_target_source_code: str) -> str: + """Collects the useful instruction to fix 'file not found' errors.""" + matched = re.search(FILE_NOT_FOUND_ERROR_REGEX, error) + if not matched: + return '' + + # Step 1: Say the file does not exist, do not include it. + wrong_file = matched.group(1) + instruction = ( + f'IMPORTANT: DO NOT include the header file {wrong_file} in the generated' + 'fuzz target again, the file does not exist in the project-under-test.\n') + + ci = context_introspector.ContextRetriever(benchmark) + # Step 2: Suggest the header/source file of the function under test. + function_file = ci.get_target_function_file_path() + if f'#include "{function_file}"' in fuzz_target_source_code: + function_file_base_name = os.path.basename(function_file) + + instruction += ( + 'In the generated code, ensure that the path prefix of ' + f'{function_file_base_name} is consistent with other include ' + f'statements related to the project ({benchmark.project}). For example,' + 'if another include statement is ' + f'#include <{benchmark.project}/header.h>, you must modify' + f' the path prefix in #include "{function_file}" to match ' + 'it, resulting in ' + f'#include <{benchmark.project}/{function_file_base_name}>.') + return instruction + + if function_file: + instruction += ( + f'If the non-existent {wrong_file} was included ' + f'for the declaration of {benchmark.function_signature}, ' + 'you must replace it with the EXACT path of the actual file ' + f'{function_file}. For example:\n' + f'\n#include "{function_file}"\n\n') + + # Step 2: Suggest similar alternatives. + similar_headers = ci.get_similar_header_file_paths(wrong_file) + if similar_headers: + statements = '\n'.join( + [f'#include "{header}"' for header in similar_headers]) + instruction += ( + 'Otherwise, consider replacing it with some of the following statements' + f'that may be correct alternatives:\n\n{statements}\n\n') + return instruction + + def main(): args = parse_args() fix_all_targets(args.target_dir, args.project) diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index ff33af2419..4f99023db7 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -139,6 +139,8 @@ def __init__(self, template_dir, 'fixer_problem.txt') self.fixer_context_template_file = self._find_template( template_dir, 'fixer_context.txt') + self.fixer_instruction_template_file = self._find_template( + template_dir, 'fixer_instruction.txt') self.triager_priming_template_file = self._find_template( template_dir, 'triager_priming.txt') self.triager_problem_template_file = self._find_template( @@ -302,11 +304,12 @@ def build_fixer_prompt(self, raw_code: str, error_desc: Optional[str], errors: list[str], - context: str = '') -> prompts.Prompt: + context: str = '', + instruction: str = '') -> prompts.Prompt: """Prepares the code-fixing prompt.""" priming, priming_weight = self._format_fixer_priming(benchmark) problem = self._format_fixer_problem(raw_code, error_desc, errors, - priming_weight, context) + priming_weight, context, instruction) self._prepare_prompt(priming, problem) return self._prompt @@ -328,7 +331,7 @@ def _format_fixer_priming(self, benchmark: Benchmark) -> Tuple[str, int]: def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str], errors: list[str], priming_weight: int, - context: str) -> str: + context: str, instruction: str) -> str: """Formats a problem for code fixer based on the template.""" with open(self.fixer_problem_template_file) as f: problem = f.read().strip() @@ -346,6 +349,12 @@ def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str], context = context_template.replace('{CONTEXT_SOURCE_CODE}', context) problem = problem.replace('{CONTEXT}', context) + if instruction: + with open(self.fixer_instruction_template_file) as f: + instruction_template = f.read().strip() + instruction = instruction_template.replace('{INSTRUCTION}', instruction) + problem = problem.replace('{INSTRUCTION}', instruction) + problem_prompt = self._prompt.create_prompt_piece(problem, 'user') template_piece = self._prompt.create_prompt_piece('{ERROR_MESSAGES}', 'user') diff --git a/prompts/template_xml/fixer_instruction.txt b/prompts/template_xml/fixer_instruction.txt new file mode 100644 index 0000000000..a46da218df --- /dev/null +++ b/prompts/template_xml/fixer_instruction.txt @@ -0,0 +1,4 @@ +Below are instructions to assist you in fixing the error. + +{INSTRUCTION} + diff --git a/prompts/template_xml/fixer_problem.txt b/prompts/template_xml/fixer_problem.txt index e182afe1ee..8f91b6e2f5 100644 --- a/prompts/template_xml/fixer_problem.txt +++ b/prompts/template_xml/fixer_problem.txt @@ -10,6 +10,7 @@ Below is the error to fix: {CONTEXT} +{INSTRUCTION} Fix code: 1. Consider possible solutions for the issues listed above.