Commit: Fix formatting
Signed-off-by: Arthur Chan <[email protected]>
arthurscchan committed Dec 3, 2024
1 parent 8ef8ecc commit b619bff
Showing 3 changed files with 13 additions and 21 deletions.
14 changes: 7 additions & 7 deletions llm_toolkit/code_fixer.py
@@ -230,8 +230,7 @@ def remove_const_from_png_symbols(content: str) -> str:
 # ========================= LLM Fixes ========================= #
 
 
-def extract_error_message(log_path: str,
-                          project_target_basename: str,
+def extract_error_message(log_path: str, project_target_basename: str,
                           language: str) -> list[str]:
   """Extracts error message and its context from the file in |log_path|."""
 
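The hunk above joins the first two parameters of extract_error_message onto one line. A minimal sketch of reproducing that reflow locally, assuming the repository formats with yapf under a Google-style config (the commit message does not name the tool, so both are assumptions):

# Hedged sketch: reformat a snippet the way this commit does, assuming yapf
# with a Google-style config (not confirmed by the commit itself).
from yapf.yapflib.yapf_api import FormatCode  # pip install yapf

SNIPPET = ("def extract_error_message(log_path: str,\n"
           "                          project_target_basename: str,\n"
           "                          language: str) -> list[str]:\n"
           "  return []\n")

formatted, changed = FormatCode(SNIPPET, style_config='google')
print(changed)    # True if yapf rewrites the snippet
print(formatted)  # expected to pack the signature more tightly, as above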
@@ -451,7 +450,8 @@ def apply_llm_fix(ai_binary: str,
   instruction = _collect_instructions(benchmark, errors,
                                       fuzz_target_source_code)
   prompt = builder.build_fixer_prompt(benchmark, fuzz_target_source_code,
-                                      error_desc, errors, context, instruction)
+                                      error_desc, errors, context,
+                                      instruction)
   prompt.save(prompt_path)
 
   fixer_model.query_llm(prompt, response_dir)
@@ -691,10 +691,10 @@ def _collect_consume_buffers(fuzz_target_source_code: str) -> str:
   for buffer_method in ['ConsumeBytes', 'ConsumeData']:
     if buffer_method in fuzz_target_source_code:
       instruction += (
-          'IMPORTANT: the harness source code contains a call to '
-          f'`{buffer_method}`. Whenever this function is used, you MUST validate'
-          ' the size of the vector returned, and make sure that the size of the '
-          f'vector is equal to argument given to `{buffer_method}`. If it is '
+          'IMPORTANT: the harness source code contains a call to `'
+          f'{buffer_method}`. Whenever this function is used, you MUST validate'
+          ' the size of the vector returned, and make sure that the size of the'
+          f' vector is equal to argument given to `{buffer_method}`. If it is '
           'not equal, the harness should not proceed.\n')
       instruction += (
           f'Furthermore, consider changing {buffer_method} to '
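The four reflowed literals only move a backtick and a space across adjacent fragments, so the instruction text the fixer model receives is unchanged. A standalone reconstruction (trimmed to the part of _collect_consume_buffers shown in this hunk) that prints the resulting instruction:

# Standalone reconstruction of the snippet above; names are taken from the
# diff, but the function body is trimmed to the part shown in this hunk.
def collect_consume_buffers(fuzz_target_source_code: str) -> str:
  instruction = ''
  for buffer_method in ['ConsumeBytes', 'ConsumeData']:
    if buffer_method in fuzz_target_source_code:
      instruction += (
          'IMPORTANT: the harness source code contains a call to `'
          f'{buffer_method}`. Whenever this function is used, you MUST validate'
          ' the size of the vector returned, and make sure that the size of the'
          f' vector is equal to argument given to `{buffer_method}`. If it is '
          'not equal, the harness should not proceed.\n')
  return instruction

harness = 'auto bytes = provider.ConsumeBytes<uint8_t>(size);'
print(collect_consume_buffers(harness))
# -> IMPORTANT: the harness source code contains a call to `ConsumeBytes`. ...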
7 changes: 3 additions & 4 deletions llm_toolkit/prompt_builder.py
@@ -1070,8 +1070,8 @@ def __init__(self,
     self.error_str = '\n'.join(errors)
 
     # Load templates.
-    self.template_file = self._find_template(template_dir,
-                                             'jvm_requirement_error_fixing.txt')
+    self.template_file = self._find_template(
+        template_dir, 'jvm_requirement_error_fixing.txt')
 
   def _find_template(self, template_dir: str, template_name: str) -> str:
     """Finds template file based on |template_dir|."""
@@ -1106,7 +1106,7 @@ def build(self,
 
     # Add the generated harness and error string to prompt
     prompt_text = prompt_text.replace('{GENERATED_HARNESS}',
-                                     self.generated_harness)
+                                      self.generated_harness)
     prompt_text = prompt_text.replace('{ERRORS}', self.error_str)
 
     self._prompt.add_priming(prompt_text)
@@ -1595,4 +1595,3 @@ def post_process_generated_code(self, generated_code: str) -> str:
     generated_code = generated_code.replace(
         'extern "C" int LLVMFuzzerTestOneInput', 'int LLVMFuzzerTestOneInput')
     return generated_code
-
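The replace call above strips the C linkage specifier from the generated entry point; a quick standalone illustration, with the sample harness line invented for demonstration:

# Standalone illustration of the substitution in post_process_generated_code;
# the sample harness line is invented for demonstration.
generated_code = ('extern "C" int LLVMFuzzerTestOneInput('
                  'const uint8_t *data, size_t size) {')
generated_code = generated_code.replace(
    'extern "C" int LLVMFuzzerTestOneInput', 'int LLVMFuzzerTestOneInput')
print(generated_code)
# -> int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {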
13 changes: 3 additions & 10 deletions run_one_experiment.py
@@ -264,8 +264,7 @@ def generate_targets_for_analysis(
     use_context: bool,
     example_pair: list[list[str]],
     prompt_builder_to_use: str = 'DEFAULT',
-    cloud_experiment_bucket: str = '',
-    jvm_error_fix: tuple[str, str] = None) -> List[str]:
+    cloud_experiment_bucket: str = '') -> List[str]:
   """Generates a set of harnesses and build scripts ready to be evaluated
   by `check_targets`. This is where the core first LLM logic is used to
   generate harnesses.
@@ -287,14 +286,8 @@ def generate_targets_for_analysis(
   else:
     context_info = {}
 
-  if jvm_error_fix:
-    # If jvm error fix tuple is provided, use error fixing prompt builder.
-    logger.info('Fixing generated JVM harness')
-    code, error = jvm_error_fix
-    builder = prompt_builder.JvmErrorFixingBuilder(model, benchmark,
-                                                   template_dir, code, error)
-  elif benchmark.test_file_path:
-    # If this is a test benchmark then we will use a test prompt builder.
+  # If this is a test benchmark then we will use a test prompt builder.
+  if benchmark.test_file_path:
     logging.info('Generating a target for test case: %s',
                  benchmark.test_file_path)
     builder = prompt_builder.TestToHarnessConverter(model, benchmark,
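With the jvm_error_fix branch removed, builder selection reduces to the test-benchmark check. A self-contained sketch of the resulting control flow; the builder classes are stubbed, their constructor arguments abbreviated, and the DefaultTemplateBuilder fallback is hypothesized since the diff cuts off before any else branch:

# Control-flow sketch after this change; Benchmark and both builders are
# stubs, and the fallback branch is hypothesized, not shown in the diff.
import logging
from dataclasses import dataclass

@dataclass
class Benchmark:
  test_file_path: str = ''

class TestToHarnessConverter:  # stub for prompt_builder.TestToHarnessConverter
  def __init__(self, model, benchmark):
    self.model, self.benchmark = model, benchmark

class DefaultTemplateBuilder:  # hypothesized fallback builder
  def __init__(self, model):
    self.model = model

def choose_builder(model, benchmark: Benchmark):
  if benchmark.test_file_path:
    # If this is a test benchmark then we will use a test prompt builder.
    logging.info('Generating a target for test case: %s',
                 benchmark.test_file_path)
    return TestToHarnessConverter(model, benchmark)
  return DefaultTemplateBuilder(model)  # hypothesized

print(type(choose_builder(object(), Benchmark('tests/ParserTest.java'))).__name__)
# -> TestToHarnessConverter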
