Skip to content

Commit

Permalink
[JVM] Add retry logic for JVM (#740)
Browse files Browse the repository at this point in the history
This PR adds back the retry logic for JVM when the first generated
harness fails to build.

---------

Signed-off-by: Arthur Chan <[email protected]>
  • Loading branch information
arthurscchan authored Dec 4, 2024
1 parent 4b8249b commit 66fff58
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 24 deletions.
4 changes: 2 additions & 2 deletions experiment/builder_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def build_and_run_local(

if not build_result.succeeded:
errors = code_fixer.extract_error_message(benchmark_log_path,
project_target_name)
project_target_name, language)
build_result.errors = errors
return build_result, None

Expand Down Expand Up @@ -980,7 +980,7 @@ def build_and_run_cloud(
if not build_result.succeeded:
errors = code_fixer.extract_error_message(
self.work_dirs.build_logs_target(generated_target_name, iteration),
os.path.basename(self.benchmark.target_path))
os.path.basename(self.benchmark.target_path), language)
build_result.errors = errors
logger.info('Cloud evaluation of %s indicates a failure: %s',
os.path.realpath(target_path), errors)
Expand Down
14 changes: 7 additions & 7 deletions experiment/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def _fix_generated_fuzz_target(self, ai_binary: str,
target_path: str, iteration: int,
build_result: BuildResult,
run_result: Optional[RunResult],
dual_logger: _Logger):
dual_logger: _Logger, language: str):
"""Fixes the generated fuzz target."""
if build_result.succeeded:
if build_result.succeeded and not language == 'jvm':
if run_result:
error_desc, errors = run_result.semantic_check.get_error_info()
else:
Expand All @@ -293,7 +293,8 @@ def _fix_generated_fuzz_target(self, ai_binary: str,
else:
error_desc, errors = None, build_result.errors
code_fixer.llm_fix(ai_binary, target_path, self.benchmark, iteration,
error_desc, errors, self.builder_runner.fixer_model_name)
error_desc, errors, self.builder_runner.fixer_model_name,
language)
shutil.copyfile(
target_path,
os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects',
Expand Down Expand Up @@ -393,17 +394,16 @@ def check_target(self, ai_binary, target_path: str) -> Result:
# Exit cond 2: fix limit is reached.
break

# 2. Fixing generated driver. Skipped for jvm projects.
if self.benchmark.language == 'jvm':
break
# 2. Fixing generated driver
llm_fix_count += 1
dual_logger.log(f'Fixing {target_path} with '
f'{self.builder_runner.fixer_model_name}, '
f'attempt {llm_fix_count}.')
try:
self._fix_generated_fuzz_target(ai_binary, generated_oss_fuzz_project,
target_path, llm_fix_count,
build_result, run_result, dual_logger)
build_result, run_result, dual_logger,
self.benchmark.language)
except Exception as e:
dual_logger.log('Exception occurred when fixing fuzz target in attempt '
f'{llm_fix_count}: {e}')
Expand Down
56 changes: 41 additions & 15 deletions llm_toolkit/code_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,30 @@ def remove_const_from_png_symbols(content: str) -> str:
# ========================= LLM Fixes ========================= #


def extract_error_message(log_path: str,
project_target_basename: str) -> list[str]:
def extract_error_message(log_path: str, project_target_basename: str,
language: str) -> list[str]:
"""Extracts error message and its context from the file in |log_path|."""

with open(log_path) as log_file:
# A more accurate way to extract the error message.
log_lines = log_file.readlines()

# Error message extraction for Java projects
if language == 'jvm':
started = False
errors = []
for log_line in log_lines:
if started:
errors.append(log_line)
if log_line == 'ERROR:__main__:Building fuzzers failed.':
break
else:
if ': error:' in log_line:
errors.append(log_line)
started = True

return errors

target_name, _ = os.path.splitext(project_target_basename)

error_lines_range: list[Optional[int]] = [None, None]
Expand Down Expand Up @@ -352,7 +368,7 @@ def group_error_messages(error_lines: list[str]) -> list[str]:

def llm_fix(ai_binary: str, target_path: str, benchmark: benchmarklib.Benchmark,
llm_fix_id: int, error_desc: Optional[str], errors: list[str],
fixer_model_name: str) -> None:
fixer_model_name: str, language: str) -> None:
"""Reads and fixes |target_path| in place with LLM based on |error_log|."""
fuzz_target_source_code = parser.parse_code(target_path)

Expand All @@ -368,6 +384,7 @@ def llm_fix(ai_binary: str, target_path: str, benchmark: benchmarklib.Benchmark,
errors,
prompt_path,
response_dir,
language,
fixer_model_name,
temperature=0.5 - llm_fix_id * 0.04)

Expand Down Expand Up @@ -409,6 +426,7 @@ def apply_llm_fix(ai_binary: str,
errors: list[str],
prompt_path: str,
response_dir: str,
language: str,
fixer_model_name: str = models.DefaultModel.name,
temperature: float = 0.4):
"""Queries LLM to fix the code."""
Expand All @@ -419,14 +437,22 @@ def apply_llm_fix(ai_binary: str,
temperature=temperature,
)

builder = prompt_builder.DefaultTemplateBuilder(fixer_model)

context = _collect_context(benchmark, errors)
instruction = _collect_instructions(benchmark, errors,
fuzz_target_source_code)
prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code,
error_desc, errors, context, instruction)
prompt.save(prompt_path)
if language == 'jvm':
builder = prompt_builder.JvmErrorFixingBuilder(fixer_model, benchmark,
fuzz_target_source_code,
errors)
prompt = builder.build([], None, None)
prompt.save(prompt_path)
else:
builder = prompt_builder.DefaultTemplateBuilder(fixer_model)

context = _collect_context(benchmark, errors)
instruction = _collect_instructions(benchmark, errors,
fuzz_target_source_code)
prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code,
error_desc, errors, context,
instruction)
prompt.save(prompt_path)

fixer_model.query_llm(prompt, response_dir)

Expand Down Expand Up @@ -665,10 +691,10 @@ def _collect_consume_buffers(fuzz_target_source_code: str) -> str:
for buffer_method in ['ConsumeBytes', 'ConsumeData']:
if buffer_method in fuzz_target_source_code:
instruction += (
'IMPORTANT: the harness source code contains a call to '
f'`{buffer_method}`. Whenever this function is used, you MUST validate'
' the size of the vector returned, and make sure that the size of the '
f'vector is equal to argument given to `{buffer_method}`. If it is '
'IMPORTANT: the harness source code contains a call to `'
f'{buffer_method}`. Whenever this function is used, you MUST validate'
' the size of the vector returned, and make sure that the size of the'
f' vector is equal to argument given to `{buffer_method}`. If it is '
'not equal, the harness should not proceed.\n')
instruction += (
f'Furthermore, consider changing {buffer_method} to '
Expand Down
76 changes: 76 additions & 0 deletions llm_toolkit/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,82 @@ def post_process_generated_code(self, generated_code: str) -> str:
return generated_code


class JvmErrorFixingBuilder(PromptBuilder):
  """Prompt builder for fixing a JVM harness that has compilation errors."""

  def __init__(self,
               model: models.LLM,
               benchmark: Benchmark,
               generated_harness: str,
               errors: list[str],
               template_dir: str = DEFAULT_TEMPLATE_DIR):
    """Stores the inputs needed to build the error-fixing prompt.

    Args:
      model: The LLM the prompt is built for.
      benchmark: The benchmark whose harness failed to build.
      generated_harness: Source code of the harness to fix.
      errors: Compiler error lines to show to the LLM.
      template_dir: Directory searched first for the prompt template.
    """
    super().__init__(model)
    self._template_dir = template_dir
    self.benchmark = benchmark
    self.generated_harness = generated_harness
    # Error lines may still carry their trailing newline (e.g. when they were
    # read with readlines()). Strip it before joining so the prompt does not
    # contain a blank line between every pair of errors.
    self.error_str = '\n'.join(error.rstrip('\n') for error in errors)

    # Load templates.
    self.template_file = self._find_template(
        template_dir, 'jvm_requirement_error_fixing.txt')

  def _find_template(self, template_dir: str, template_name: str) -> str:
    """Finds template file based on |template_dir|."""
    preferred_template = os.path.join(template_dir, template_name)
    # Use the preferred template if it exists.
    if os.path.isfile(preferred_template):
      return preferred_template
    # Fall back to the default template.
    return os.path.join(DEFAULT_TEMPLATE_DIR, template_name)

  def _get_template(self, template_file: str) -> str:
    """Reads the template for prompts."""
    with open(template_file) as file:
      return file.read()

  def build(self,
            example_pair: list[list[str]],
            project_example_content: Optional[list[list[str]]] = None,
            project_context_content: Optional[dict] = None) -> prompts.Prompt:
    """Constructs the error-fixing prompt from the template and returns it.

    The |example_pair|, |project_example_content| and
    |project_context_content| parameters are ignored.
    """
    prompt_text = self._get_template(self.template_file)

    # Fill in the template-level placeholders first, so that text inserted
    # later (harness code, compiler output) is never scanned for them.
    target_repository = oss_fuzz_checkout.get_project_repository(
        self.benchmark.project)
    prompt_text = prompt_text.replace('{TARGET_REPO}', target_repository)

    # The template requires the harness class to be named {HARNESS_NAME};
    # derive that name from the target file name without its extension.
    harness_name, _ = os.path.splitext(
        os.path.basename(self.benchmark.target_path))
    if harness_name:
      prompt_text = prompt_text.replace('{HARNESS_NAME}', harness_name)

    # Add the generated harness and error string to prompt
    prompt_text = prompt_text.replace('{GENERATED_HARNESS}',
                                      self.generated_harness)
    prompt_text = prompt_text.replace('{ERRORS}', self.error_str)

    self._prompt.add_priming(prompt_text)
    return self._prompt

  def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
                         error_desc: Optional[str],
                         errors: list[str]) -> prompts.Prompt:
    """Builds a fixer prompt."""
    # Do nothing for jvm project now.
    return self._prompt

  def build_triager_prompt(self, benchmark: Benchmark, driver_code: str,
                           crash_info: str, crash_func: dict) -> prompts.Prompt:
    """Builds a triager prompt."""
    # Do nothing for jvm project now.
    return self._prompt

  def post_process_generated_code(self, generated_code: str) -> str:
    """Allows prompt builder to adjust the generated code."""
    return generated_code


class DefaultPythonTemplateBuilder(PromptBuilder):
"""Default builder for Python projects."""

Expand Down
93 changes: 93 additions & 0 deletions prompts/template_xml/jvm_requirement_error_fixing.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
I'm a security engineer looking to convert unit tests into fuzzing harnesses. I got some compilation errors and want you to help fix them.

The target library is {TARGET_REPO}.

The target project is written in the Java programming language.

Because the project is written in Java, the harness should also be written in Java.
The fuzzing harness should be executable under the Jazzer fuzzing framework.

I have already done the conversion and here is my harness code.
<code>
{GENERATED_HARNESS}
</code>

I got the following errors from the compiler. Please help me fix them while keeping the formatting and all other logic unchanged.
<error>
{ERRORS}
</error>

If an import for a class used in the harness is missing and you cannot determine the correct import statement, remove the use of that class instead.
In your response, include ONLY the code for the harness, nothing more. You should wrap the code in <code></code> tags.

Here is an additional list of requirements that you MUST follow.
<requirements>
<item>NEVER use any methods from the <code>java.util.Random</code> class in the generated code.</item>
<item>NEVER use any classes or methods in the <code>java.lang.reflect</code> package in the generated code.</item>
<item>NEVER use the @FuzzTest annotation for specifying the fuzzing method.</item>
<item>NEVER use any assert, printing and logging statements in the generated harness.</item>
<item>NEVER use any multithreading or multi-processing approach.</item>
<item>You MUST create the object before calling the target method.</item>
<item>Please use {HARNESS_NAME} as the Java class name.</item>
<item>You MUST invoke the close method of any resource class objects that implements the java.lang.AutoCloseable interface in the finally block after the target method is invoked.</item>
<item>Always create the fuzzing harness from the following templates:
<code>
import com.code_intelligence.jazzer.api.FuzzedDataProvider;
// Other imports

public class {HARNESS_NAME} {
public static void fuzzerInitialize() {
// Initializing objects for fuzzing
}

public static void fuzzerTearDown() {
// Tear down objects after fuzzing
}

public static void fuzzerTestOneInput(FuzzedDataProvider data) {
// Use the FuzzedDataProvider object to generate random data for fuzzing

// Fuzz by invoking the target method with random parameters / objects generated above.
}
}
</code></item>
<item>
You MUST ONLY use any of the following methods from the FuzzedDataProvider of the Jazzer framework for generating random data for fuzzing.
If the needed return value is not found in the table, try to use constructors or methods to create the needed random object. But you MUST try your best to randomise the random object with the methods in the table.

| Method | Return Value |
|---------------------------------------------|---------------------------------------|
| `consumeBytes(int length)` | `byte[]` |
| `consumeRemainingAsBytes()` | `byte[]` |
| `consumeString(int length)` | `String` |
| `consumeRemainingAsString()` | `String` |
| `consumeBoolean()` | `boolean` |
| `consumeInt(int min, int max)` | `int` |
| `consumeInt()` | `int` |
| `consumeLong(long min, long max)` | `long` |
| `consumeLong()` | `long` |
| `consumeFloat(float min, float max)` | `float` |
| `consumeFloat()` | `float` |
| `consumeDouble(double min, double max)` | `double` |
| `consumeDouble()` | `double` |
| `consumeChar()` | `char` |
| `consumeChar(char min, char max)` | `char` |
| `consumeShort(short min, short max)` | `short` |
| `consumeShort()` | `short` |
| `consumeRemainingAsCharSequence()` | `CharSequence` |
| `consumeBytestring()` | `byte[]` |
| `consumeBigInteger(int minNumBits)` | `BigInteger` |
| `consumeEnum(Class<E> enumType)` | `E` (Enum type) |
| `consumeProbabilityDouble()` | `double` |
| `consumeFraction()` | `double` |
| `pickValue(T... values)` | `T` (Type of value) |
| `pickValue(List<T> values)` | `T` (Type of value) |
| `consumeByte()` | `byte` |
| `consumeIntList(int length)` | `List<Integer>` |
| `consumeLongList(int length)` | `List<Long>` |
| `consumeFloatList(int length)` | `List<Float>` |
| `consumeDoubleList(int length)` | `List<Double>` |
| `consumeCharList(int length)` | `List<Character>` |

</item>
</requirements>

0 comments on commit 66fff58

Please sign in to comment.