diff --git a/data_prep/introspector.py b/data_prep/introspector.py
index 123154ac2b..bcbd859100 100755
--- a/data_prep/introspector.py
+++ b/data_prep/introspector.py
@@ -228,6 +228,15 @@ def query_introspector_cfg(project: str) -> dict:
return _get_data(resp, 'project', {})
+def query_introspector_source_file_path(project: str, func_sig: str) -> str:
+  """Queries FuzzIntrospector API for the file path of |func_sig|.
+
+  Sends |project| and |func_sig| to the INTROSPECTOR_FUNCTION_SOURCE endpoint
+  and extracts the 'filepath' entry of the response through _get_data, with
+  '' passed as its last (presumably fallback) argument.
+  """
+  params = {'project': project, 'function_signature': func_sig}
+  response = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, params)
+  return _get_data(response, 'filepath', '')
+
+
def query_introspector_function_source(project: str, func_sig: str) -> str:
"""Queries FuzzIntrospector API for source code of |func_sig|."""
resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, {
diff --git a/data_prep/project_context/context_introspector.py b/data_prep/project_context/context_introspector.py
index 0e094910ed..ccecdc630c 100644
--- a/data_prep/project_context/context_introspector.py
+++ b/data_prep/project_context/context_introspector.py
@@ -3,6 +3,7 @@
import logging
import os
+from difflib import SequenceMatcher
from typing import Any
from data_prep import introspector
@@ -217,3 +218,37 @@ def get_type_def(self, type_name: str) -> str:
type_names.append(new_type_name)
return type_def
+
+ def get_similar_header_file_paths(self, wrong_file: str) -> list[str]:
+   """Returns up to 5 project header paths most similar to |wrong_file|.
+
+   Similarity is the difflib.SequenceMatcher ratio between |wrong_file| and
+   each header path, with '_', '/', '-' and '.' passed as junk characters.
+   Headers are returned from the most to the least similar.
+   """
+
+   def _similarity(header: str) -> float:
+     matcher = SequenceMatcher(lambda char: char in ['_', '/', '-', '.'],
+                               wrong_file, header)
+     return matcher.ratio()
+
+   all_headers = introspector.query_introspector_header_files(
+       self._benchmark.project)
+   # Keyed by header path, so duplicate paths collapse into one candidate.
+   scores = {header: _similarity(header) for header in all_headers}
+   ranked = sorted(scores, key=lambda header: scores[header], reverse=True)
+   return ranked[:5]
+
+ def get_target_function_file_path(self) -> str:
+   """Retrieves the header/source file of the function under test.
+
+   Looks for a project header whose basename starts with the basename (without
+   extension) of the function's source file, preferring a header whose stem
+   is exactly that basename. When no such header exists, or when no source
+   file name is available, returns the source file path as reported by
+   FuzzIntrospector.
+   """
+   source_file = introspector.query_introspector_source_file_path(
+       self._benchmark.project, self._benchmark.function_signature)
+   source_file_base, _ = os.path.splitext(os.path.basename(source_file))
+
+   # Step 1: Find a header file that shares the same name as the source file.
+   # An empty base name would make startswith() match every header and return
+   # an arbitrary one, so the header lookup is skipped in that case.
+   # TODO: Make this more robust, e.g., when header file and base file do not
+   # share the same basename.
+   if source_file_base:
+     header_list = introspector.query_introspector_header_files(
+         self._benchmark.project)
+     candidate_headers = [
+         header for header in header_list
+         if os.path.basename(header).startswith(source_file_base)
+     ]
+     # Prefer an exact stem match (foo.c -> foo.h) over a mere prefix match
+     # (foo.c -> foo_internal.h).
+     for header in candidate_headers:
+       header_base, _ = os.path.splitext(os.path.basename(header))
+       if header_base == source_file_base:
+         return header
+     if candidate_headers:
+       return candidate_headers[0]
+
+   # Step 2: Use the source file if it does not have a same-name-header.
+   return source_file
diff --git a/llm_toolkit/code_fixer.py b/llm_toolkit/code_fixer.py
index e939a823d1..781da4112f 100755
--- a/llm_toolkit/code_fixer.py
+++ b/llm_toolkit/code_fixer.py
@@ -29,6 +29,7 @@
ERROR_LINES = 20
NO_MEMBER_ERROR_REGEX = r"error: no member named '.*' in '([^':]*):?.*'"
+FILE_NOT_FOUND_ERROR_REGEX = r"fatal error: '([^']*)' file not found"
def parse_args():
@@ -410,8 +411,10 @@ def apply_llm_fix(ai_binary: str,
builder = prompt_builder.DefaultTemplateBuilder(fixer_model)
context = _collect_context(benchmark, errors)
+ instruction = _collect_instructions(benchmark, errors,
+ fuzz_target_source_code)
prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code,
- error_desc, errors, context)
+ error_desc, errors, context, instruction)
prompt.save(prompt_path)
fixer_model.generate_code(prompt, response_dir)
@@ -441,6 +444,69 @@ def _collect_context_no_member(benchmark: benchmarklib.Benchmark,
return ci.get_type_def(target_type)
+def _collect_instructions(benchmark: benchmarklib.Benchmark, errors: list[str],
+                          fuzz_target_source_code: str) -> str:
+  """Builds the fixing instructions for all |errors|.
+
+  Each error is passed to _collect_instruction_file_not_found and the
+  resulting pieces are concatenated in the order of |errors|. Returns '' when
+  |errors| is empty.
+  """
+  if not errors:
+    return ''
+
+  return ''.join(
+      _collect_instruction_file_not_found(benchmark, error,
+                                          fuzz_target_source_code)
+      for error in errors)
+
+
+def _collect_instruction_file_not_found(benchmark: benchmarklib.Benchmark,
+                                        error: str,
+                                        fuzz_target_source_code: str) -> str:
+  """Collects the useful instruction to fix 'file not found' errors.
+
+  Args:
+    benchmark: Benchmark whose project and function signature are quoted in
+      the instruction.
+    error: A single build error message.
+    fuzz_target_source_code: Source code of the fuzz target being fixed.
+
+  Returns:
+    The instruction text, or '' when |error| does not match
+    FILE_NOT_FOUND_ERROR_REGEX.
+  """
+  matched = re.search(FILE_NOT_FOUND_ERROR_REGEX, error)
+  if not matched:
+    return ''
+
+  # Step 1: Say the file does not exist, do not include it.
+  wrong_file = matched.group(1)
+  instruction = (
+      f'IMPORTANT: DO NOT include the header file {wrong_file} in the '
+      'generated fuzz target again, the file does not exist in the '
+      'project-under-test.\n')
+
+  ci = context_introspector.ContextRetriever(benchmark)
+  # Step 2: Suggest the header/source file of the function under test.
+  function_file = ci.get_target_function_file_path()
+  if f'#include "{function_file}"' in fuzz_target_source_code:
+    # The fuzz target already includes the function's file verbatim; advise
+    # aligning its path prefix with the other project includes and stop here.
+    function_file_base_name = os.path.basename(function_file)
+
+    instruction += (
+        'In the generated code, ensure that the path prefix of '
+        f'{function_file_base_name} is consistent with other include '
+        f'statements related to the project ({benchmark.project}). For '
+        'example, if another include statement is '
+        f'#include <{benchmark.project}/header.h>, you must modify'
+        f' the path prefix in #include "{function_file}" to match '
+        'it, resulting in '
+        f'#include <{benchmark.project}/{function_file_base_name}>.')
+    return instruction
+
+  if function_file:
+    instruction += (
+        f'If the non-existent {wrong_file} was included '
+        f'for the declaration of {benchmark.function_signature}, '
+        'you must replace it with the EXACT path of the actual file '
+        f'{function_file}. For example:\n'
+        f'\n#include "{function_file}"\n\n')
+
+  # Step 3: Suggest similar alternatives.
+  similar_headers = ci.get_similar_header_file_paths(wrong_file)
+  if similar_headers:
+    statements = '\n'.join(
+        f'#include "{header}"' for header in similar_headers)
+    instruction += (
+        'Otherwise, consider replacing it with some of the following '
+        f'statements that may be correct alternatives:\n\n{statements}\n\n')
+  return instruction
+
+
def main():
args = parse_args()
fix_all_targets(args.target_dir, args.project)
diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py
index ff33af2419..4f99023db7 100644
--- a/llm_toolkit/prompt_builder.py
+++ b/llm_toolkit/prompt_builder.py
@@ -139,6 +139,8 @@ def __init__(self,
template_dir, 'fixer_problem.txt')
self.fixer_context_template_file = self._find_template(
template_dir, 'fixer_context.txt')
+ self.fixer_instruction_template_file = self._find_template(
+ template_dir, 'fixer_instruction.txt')
self.triager_priming_template_file = self._find_template(
template_dir, 'triager_priming.txt')
self.triager_problem_template_file = self._find_template(
@@ -302,11 +304,12 @@ def build_fixer_prompt(self,
raw_code: str,
error_desc: Optional[str],
errors: list[str],
- context: str = '') -> prompts.Prompt:
+ context: str = '',
+ instruction: str = '') -> prompts.Prompt:
"""Prepares the code-fixing prompt."""
priming, priming_weight = self._format_fixer_priming(benchmark)
problem = self._format_fixer_problem(raw_code, error_desc, errors,
- priming_weight, context)
+ priming_weight, context, instruction)
self._prepare_prompt(priming, problem)
return self._prompt
@@ -328,7 +331,7 @@ def _format_fixer_priming(self, benchmark: Benchmark) -> Tuple[str, int]:
def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str],
errors: list[str], priming_weight: int,
- context: str) -> str:
+ context: str, instruction: str) -> str:
"""Formats a problem for code fixer based on the template."""
with open(self.fixer_problem_template_file) as f:
problem = f.read().strip()
@@ -346,6 +349,12 @@ def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str],
context = context_template.replace('{CONTEXT_SOURCE_CODE}', context)
problem = problem.replace('{CONTEXT}', context)
+ if instruction:
+ with open(self.fixer_instruction_template_file) as f:
+ instruction_template = f.read().strip()
+ instruction = instruction_template.replace('{INSTRUCTION}', instruction)
+ problem = problem.replace('{INSTRUCTION}', instruction)
+
problem_prompt = self._prompt.create_prompt_piece(problem, 'user')
template_piece = self._prompt.create_prompt_piece('{ERROR_MESSAGES}',
'user')
diff --git a/prompts/template_xml/fixer_instruction.txt b/prompts/template_xml/fixer_instruction.txt
new file mode 100644
index 0000000000..a46da218df
--- /dev/null
+++ b/prompts/template_xml/fixer_instruction.txt
@@ -0,0 +1,4 @@
+Below are instructions to assist you in fixing the error.
+
+{INSTRUCTION}
+
diff --git a/prompts/template_xml/fixer_problem.txt b/prompts/template_xml/fixer_problem.txt
index e182afe1ee..8f91b6e2f5 100644
--- a/prompts/template_xml/fixer_problem.txt
+++ b/prompts/template_xml/fixer_problem.txt
@@ -10,6 +10,7 @@ Below is the error to fix:
{CONTEXT}
+{INSTRUCTION}
Fix code:
1. Consider possible solutions for the issues listed above.