From 66fff5888dc445d569f3f6bc5a894a2bca54878a Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Wed, 4 Dec 2024 22:06:25 +0000 Subject: [PATCH] [JVM] Add retry logic for JVM (#740) This PR add back the retry logic for JVM when the first generated harness is failed to build. --------- Signed-off-by: Arthur Chan --- experiment/builder_runner.py | 4 +- experiment/evaluator.py | 14 +-- llm_toolkit/code_fixer.py | 56 ++++++++--- llm_toolkit/prompt_builder.py | 76 +++++++++++++++ .../jvm_requirement_error_fixing.txt | 93 +++++++++++++++++++ 5 files changed, 219 insertions(+), 24 deletions(-) create mode 100644 prompts/template_xml/jvm_requirement_error_fixing.txt diff --git a/experiment/builder_runner.py b/experiment/builder_runner.py index 50879f6d8..bd70ea4c6 100644 --- a/experiment/builder_runner.py +++ b/experiment/builder_runner.py @@ -493,7 +493,7 @@ def build_and_run_local( if not build_result.succeeded: errors = code_fixer.extract_error_message(benchmark_log_path, - project_target_name) + project_target_name, language) build_result.errors = errors return build_result, None @@ -980,7 +980,7 @@ def build_and_run_cloud( if not build_result.succeeded: errors = code_fixer.extract_error_message( self.work_dirs.build_logs_target(generated_target_name, iteration), - os.path.basename(self.benchmark.target_path)) + os.path.basename(self.benchmark.target_path), language) build_result.errors = errors logger.info('Cloud evaluation of %s indicates a failure: %s', os.path.realpath(target_path), errors) diff --git a/experiment/evaluator.py b/experiment/evaluator.py index 6c2259ff4..9ddd3e27e 100644 --- a/experiment/evaluator.py +++ b/experiment/evaluator.py @@ -281,9 +281,9 @@ def _fix_generated_fuzz_target(self, ai_binary: str, target_path: str, iteration: int, build_result: BuildResult, run_result: Optional[RunResult], - dual_logger: _Logger): + dual_logger: _Logger, language: str): """Fixes the generated fuzz target.""" - if build_result.succeeded: + if build_result.succeeded 
and not language == 'jvm': if run_result: error_desc, errors = run_result.semantic_check.get_error_info() else: @@ -293,7 +293,8 @@ def _fix_generated_fuzz_target(self, ai_binary: str, else: error_desc, errors = None, build_result.errors code_fixer.llm_fix(ai_binary, target_path, self.benchmark, iteration, - error_desc, errors, self.builder_runner.fixer_model_name) + error_desc, errors, self.builder_runner.fixer_model_name, + language) shutil.copyfile( target_path, os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', @@ -393,9 +394,7 @@ def check_target(self, ai_binary, target_path: str) -> Result: # Exit cond 2: fix limit is reached. break - # 2. Fixing generated driver. Skipped for jvm projects. - if self.benchmark.language == 'jvm': - break + # 2. Fixing generated driver llm_fix_count += 1 dual_logger.log(f'Fixing {target_path} with ' f'{self.builder_runner.fixer_model_name}, ' @@ -403,7 +402,8 @@ def check_target(self, ai_binary, target_path: str) -> Result: try: self._fix_generated_fuzz_target(ai_binary, generated_oss_fuzz_project, target_path, llm_fix_count, - build_result, run_result, dual_logger) + build_result, run_result, dual_logger, + self.benchmark.language) except Exception as e: dual_logger.log('Exception occurred when fixing fuzz target in attempt ' f'{llm_fix_count}: {e}') diff --git a/llm_toolkit/code_fixer.py b/llm_toolkit/code_fixer.py index 8a9b66c97..e40fb8712 100755 --- a/llm_toolkit/code_fixer.py +++ b/llm_toolkit/code_fixer.py @@ -230,14 +230,30 @@ def remove_const_from_png_symbols(content: str) -> str: # ========================= LLM Fixes ========================= # -def extract_error_message(log_path: str, - project_target_basename: str) -> list[str]: +def extract_error_message(log_path: str, project_target_basename: str, + language: str) -> list[str]: """Extracts error message and its context from the file in |log_path|.""" with open(log_path) as log_file: # A more accurate way to extract the error message. 
log_lines = log_file.readlines() + # Error message extraction for Java projects + if language == 'jvm': + started = False + errors = [] + for log_line in log_lines: + if started: + errors.append(log_line) + if log_line == 'ERROR:__main__:Building fuzzers failed.': + break + else: + if ': error:' in log_line: + errors.append(log_line) + started = True + + return errors + target_name, _ = os.path.splitext(project_target_basename) error_lines_range: list[Optional[int]] = [None, None] @@ -352,7 +368,7 @@ def group_error_messages(error_lines: list[str]) -> list[str]: def llm_fix(ai_binary: str, target_path: str, benchmark: benchmarklib.Benchmark, llm_fix_id: int, error_desc: Optional[str], errors: list[str], - fixer_model_name: str) -> None: + fixer_model_name: str, language: str) -> None: """Reads and fixes |target_path| in place with LLM based on |error_log|.""" fuzz_target_source_code = parser.parse_code(target_path) @@ -368,6 +384,7 @@ def llm_fix(ai_binary: str, target_path: str, benchmark: benchmarklib.Benchmark, errors, prompt_path, response_dir, + language, fixer_model_name, temperature=0.5 - llm_fix_id * 0.04) @@ -409,6 +426,7 @@ def apply_llm_fix(ai_binary: str, errors: list[str], prompt_path: str, response_dir: str, + language: str, fixer_model_name: str = models.DefaultModel.name, temperature: float = 0.4): """Queries LLM to fix the code.""" @@ -419,14 +437,22 @@ def apply_llm_fix(ai_binary: str, temperature=temperature, ) - builder = prompt_builder.DefaultTemplateBuilder(fixer_model) - - context = _collect_context(benchmark, errors) - instruction = _collect_instructions(benchmark, errors, - fuzz_target_source_code) - prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code, - error_desc, errors, context, instruction) - prompt.save(prompt_path) + if language == 'jvm': + builder = prompt_builder.JvmErrorFixingBuilder(fixer_model, benchmark, + fuzz_target_source_code, + errors) + prompt = builder.build([], None, None) + prompt.save(prompt_path) + 
else: + builder = prompt_builder.DefaultTemplateBuilder(fixer_model) + + context = _collect_context(benchmark, errors) + instruction = _collect_instructions(benchmark, errors, + fuzz_target_source_code) + prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code, + error_desc, errors, context, + instruction) + prompt.save(prompt_path) fixer_model.query_llm(prompt, response_dir) @@ -665,10 +691,10 @@ def _collect_consume_buffers(fuzz_target_source_code: str) -> str: for buffer_method in ['ConsumeBytes', 'ConsumeData']: if buffer_method in fuzz_target_source_code: instruction += ( - 'IMPORTANT: the harness source code contains a call to ' - f'`{buffer_method}`. Whenever this function is used, you MUST validate' - ' the size of the vector returned, and make sure that the size of the ' - f'vector is equal to argument given to `{buffer_method}`. If it is ' + 'IMPORTANT: the harness source code contains a call to `' + f'{buffer_method}`. Whenever this function is used, you MUST validate' + ' the size of the vector returned, and make sure that the size of the' + f' vector is equal to argument given to `{buffer_method}`. 
If it is ' 'not equal, the harness should not proceed.\n') instruction += ( f'Furthermore, consider changing {buffer_method} to ' diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index dfd93f60e..814b0e210 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -1054,6 +1054,82 @@ def post_process_generated_code(self, generated_code: str) -> str: return generated_code +class JvmErrorFixingBuilder(PromptBuilder): + """Prompt builder for fixing JVM harness with complication error.""" + + def __init__(self, + model: models.LLM, + benchmark: Benchmark, + generated_harness: str, + errors: list[str], + template_dir: str = DEFAULT_TEMPLATE_DIR): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.generated_harness = generated_harness + self.error_str = '\n'.join(errors) + + # Load templates. + self.template_file = self._find_template( + template_dir, 'jvm_requirement_error_fixing.txt') + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def build(self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. 
+ """ + with open(self.template_file, 'r') as f: + prompt_text = f.read() + + # Format the repository + target_repository = oss_fuzz_checkout.get_project_repository( + self.benchmark.project) + prompt_text = prompt_text.replace('{TARGET_REPO}', target_repository) + + # Add the generated harness and error string to prompt + prompt_text = prompt_text.replace('{GENERATED_HARNESS}', + self.generated_harness) + prompt_text = prompt_text.replace('{ERRORS}', self.error_str) + + self._prompt.add_priming(prompt_text) + return self._prompt + + def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, + error_desc: Optional[str], + errors: list[str]) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for jvm project now. + return self._prompt + + def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, + crash_info: str, crash_func: dict) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for jvm project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + return generated_code + + class DefaultPythonTemplateBuilder(PromptBuilder): """Default builder for Python projects.""" diff --git a/prompts/template_xml/jvm_requirement_error_fixing.txt b/prompts/template_xml/jvm_requirement_error_fixing.txt new file mode 100644 index 000000000..00c3d1e85 --- /dev/null +++ b/prompts/template_xml/jvm_requirement_error_fixing.txt @@ -0,0 +1,93 @@ +I'm a security engineer looking to convert unit tests into fuzzing harnesses. I got some compilation errors and want you to help fix them. + +The target library is {TARGET_REPO}. + +The target project is written in the Java programming language. + +This is a Java programming language so the harness should be written in Java. +The fuzzing harness should be executable under the Jazzer fuzzing framework. + +I have already done the conversion and here is my harness code. 
+ +{GENERATED_HARNESS} + + +And I got the following errors from the compiler. Please help me fix them while keeping the formatting and all other logic unchanged. + +{ERRORS} + + +If a class is used without an import and you cannot determine the correct import statement, try removing the use of that class. +In your response, include ONLY the code for the harness, nothing more. You should wrap the code in tags. + +Here is an additional list of requirements that you MUST follow. + +NEVER use any methods from the java.util.Random class in the generated code. +NEVER use any classes or methods in the java.lang.reflect package in the generated code. +NEVER use the @FuzzTest annotation for specifying the fuzzing method. +NEVER use any assert, printing or logging statements in the generated harness. +NEVER use any multithreading or multi-processing approach. +You MUST create the object before calling the target method. +Please use {HARNESS_NAME} as the Java class name. +You MUST invoke the close method of any resource class object that implements the java.lang.AutoCloseable interface in the finally block after the target method is invoked. +Always create the fuzzing harness from the following template: + +import com.code_intelligence.jazzer.api.FuzzedDataProvider; +// Other imports + +public class {HARNESS_NAME} { + public static void fuzzerInitialize() { + // Initializing objects for fuzzing + } + + public static void fuzzerTearDown() { + // Tear down objects after fuzzing + } + + public static void fuzzerTestOneInput(FuzzedDataProvider data) { + // Use the FuzzedDataProvider object to generate random data for fuzzing + + // Fuzz by invoking the target method with random parameters / objects generated above. + } +} + + +You MUST ONLY use any of the following methods from the FuzzedDataProvider of the Jazzer framework for generating random data for fuzzing. 
+If the needed return value is not found in the table, try use constructors or methods to create the needed random object. But you MUST try your best to randomise the random object with the methods in the table. + +| Method | Return Value | +|---------------------------------------------|---------------------------------------| +| `consumeBytes(int length)` | `byte[]` | +| `consumeRemainingAsBytes()` | `byte[]` | +| `consumeString(int length)` | `String` | +| `consumeRemainingAsString()` | `String` | +| `consumeBoolean()` | `boolean` | +| `consumeInt(int min, int max)` | `int` | +| `consumeInt()` | `int` | +| `consumeLong(long min, long max)` | `long` | +| `consumeLong()` | `long` | +| `consumeFloat(float min, float max)` | `float` | +| `consumeFloat()` | `float` | +| `consumeDouble(double min, double max)` | `double` | +| `consumeDouble()` | `double` | +| `consumeChar()` | `char` | +| `consumeChar(char min, char max)` | `char` | +| `consumeShort(short min, short max)` | `short` | +| `consumeShort()` | `short` | +| `consumeRemainingAsCharSequence()` | `CharSequence` | +| `consumeBytestring()` | `byte[]` | +| `consumeBigInteger(int minNumBits)` | `BigInteger` | +| `consumeEnum(Class enumType)` | `E` (Enum type) | +| `consumeProbabilityDouble()` | `double` | +| `consumeFraction()` | `double` | +| `pickValue(T... values)` | `T` (Type of value) | +| `pickValue(List values)` | `T` (Type of value) | +| `consumeByte()` | `byte` | +| `consumeIntList(int length)` | `List` | +| `consumeLongList(int length)` | `List` | +| `consumeFloatList(int length)` | `List` | +| `consumeDoubleList(int length)` | `List` | +| `consumeCharList(int length)` | `List` | + + +