Support running multiple bash commands and compile in one query #736

Merged: 9 commits, merged on Dec 15, 2024
Changes from 8 commits

1 change: 1 addition & 0 deletions .isort.cfg
@@ -1,2 +1,3 @@
 [settings]
 src_paths=.
+line_length = 80
57 changes: 45 additions & 12 deletions agent/base_agent.py
@@ -10,7 +10,6 @@
 import logger
 import utils
 from llm_toolkit.models import LLM
-from llm_toolkit.prompt_builder import DefaultTemplateBuilder
 from llm_toolkit.prompts import Prompt
 from results import Result
 from tool.base_tool import BaseTool
@@ -63,6 +62,11 @@ def _parse_tag(self, response: str, tag: str) -> str:
     match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
     return match.group(1).strip() if match else ''
 
+  def _parse_tags(self, response: str, tag: str) -> list[str]:
+    """Parses the XML-style tags from LLM response."""
+    matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
[Review comment, Collaborator]
For future work: Should we experiment with JSON formatting at some point? Let's create an issue if so.

[Reply, Collaborator (Author)]
Done: #743
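
For reference, a minimal standalone sketch (not part of the diff) of what the new `_parse_tags` helper returns for a response containing several XML-style tags:

```python
import re

def _parse_tags(response: str, tag: str) -> list[str]:
  """Parses every <tag>...</tag> block from an LLM response."""
  matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
  return [content.strip() for content in matches]

response = ('<bash>\nls /src\n</bash>\n'
            'Some reasoning text from the LLM...\n'
            '<bash>\ncat /src/build.sh\n</bash>')
print(_parse_tags(response, 'bash'))
# ['ls /src', 'cat /src/build.sh']
```

A JSON-based protocol (the follow-up tracked in #743) would swap this tag parsing for `json.loads` on a structured response.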

+    return [content.strip() for content in matches]
+
   def _filter_code(self, raw_code_block: str) -> str:
     """Filters out irrelevant lines from |raw_code_block|."""
     # TODO(dongge): Move this function to a separate module.
@@ -74,27 +78,56 @@ def _filter_code(self, raw_code_block: str) -> str:
     filtered_code_block = '\n'.join(filtered_lines)
     return filtered_code_block

-  def _format_bash_execution_result(self, process: sp.CompletedProcess) -> str:
+  def _format_bash_execution_result(
+      self,
+      process: sp.CompletedProcess,
+      previous_prompt: Optional[Prompt] = None) -> str:
     """Formats a prompt based on bash execution result."""
-    stdout = self.llm.truncate_prompt(process.stdout)
-    # TODO(dongge) Share input limit evenly if both stdout and stderr overlong.
-    stderr = self.llm.truncate_prompt(process.stderr, stdout)
+    if previous_prompt:
+      previous_prompt_text = previous_prompt.get()
+    else:
+      previous_prompt_text = ''
+    stdout = self.llm.truncate_prompt(process.stdout,
+                                      previous_prompt_text).strip()
+    stderr = self.llm.truncate_prompt(process.stderr,
+                                      stdout + previous_prompt_text).strip()
     return (f'<bash>\n{process.args}\n</bash>\n'
             f'<return code>\n{process.returncode}\n</return code>\n'
             f'<stdout>\n{stdout}\n</stdout>\n'
             f'<stderr>\n{stderr}\n</stderr>\n')
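
As an illustration (an assumed example, not output from the PR), the wrapper produces blocks like the following for a trivial command, with truncation omitted:

```python
import subprocess as sp

# Mimic _format_bash_execution_result for one command, skipping truncation.
process = sp.run('echo hello', shell=True, capture_output=True, text=True)
print(f'<bash>\n{process.args}\n</bash>\n'
      f'<return code>\n{process.returncode}\n</return code>\n'
      f'<stdout>\n{process.stdout.strip()}\n</stdout>\n'
      f'<stderr>\n{process.stderr.strip()}\n</stderr>\n')
# <bash>
# echo hello
# </bash>
# <return code>
# 0
# </return code>
# <stdout>
# hello
# </stdout>
# <stderr>
#
# </stderr>
```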

-  def _container_handle_bash_command(self, command: str,
-                                     tool: BaseTool) -> Prompt:
+  def _container_handle_bash_command(self, response: str, tool: BaseTool,
+                                     prompt: Prompt) -> Prompt:
     """Handles the command from LLM with container |tool|."""
-    prompt_text = self._format_bash_execution_result(tool.execute(command))
-    return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
-
-  def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
+    prompt_text = ''
+    for command in self._parse_tags(response, 'bash'):
+      prompt_text += self._format_bash_execution_result(
+          tool.execute(command), previous_prompt=prompt) + '\n'
+    prompt.append(prompt_text)
+    return prompt
+
+  def _container_handle_invalid_tool_usage(self, tool: BaseTool, cur_round: int,
+                                           response: str,
+                                           prompt: Prompt) -> Prompt:
     """Formats a prompt to re-teach LLM how to use the |tool|."""
+    logger.warning('ROUND %02d Invalid response from LLM: %s',
+                   cur_round,
+                   response,
+                   trial=self.trial)
     prompt_text = (f'No valid instruction received, Please follow the '
                    f'interaction protocols:\n{tool.tutorial()}')
-    return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
+    prompt.append(prompt_text)
+    return prompt
+
+  def _container_handle_bash_commands(self, response: str, tool: BaseTool,
+                                      prompt: Prompt) -> Prompt:
+    """Handles the command from LLM with container |tool|."""
+    prompt_text = ''
+    for command in self._parse_tags(response, 'bash'):
+      prompt_text += self._format_bash_execution_result(
+          tool.execute(command), previous_prompt=prompt) + '\n'
+    prompt.append(prompt_text)
+    return prompt

   def _sleep_random_duration(
       self,
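
The core behavioral change in this file: every `<bash>` block in a response is executed and all results are accumulated into one follow-up prompt, instead of returning after the first command. A self-contained sketch under simplifying assumptions (`FakePrompt` is a hypothetical stand-in for the repo's `Prompt`; the real code executes commands inside a container tool):

```python
import re
import subprocess as sp

class FakePrompt:
  """Hypothetical stand-in for llm_toolkit.prompts.Prompt."""

  def __init__(self):
    self._text = ''

  def append(self, text: str) -> None:
    self._text += text

  def get(self) -> str:
    return self._text

def parse_tags(response: str, tag: str) -> list[str]:
  return [
      c.strip()
      for c in re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
  ]

def handle_bash_commands(response: str, prompt: FakePrompt) -> FakePrompt:
  # Execute every <bash> command and append each formatted result.
  for command in parse_tags(response, 'bash'):
    proc = sp.run(command, shell=True, capture_output=True, text=True)
    prompt.append(f'<bash>\n{command}\n</bash>\n'
                  f'<return code>\n{proc.returncode}\n</return code>\n'
                  f'<stdout>\n{proc.stdout.strip()}\n</stdout>\n'
                  f'<stderr>\n{proc.stderr.strip()}\n</stderr>\n')
  return prompt

response = '<bash>echo one</bash>\n<bash>echo two</bash>'
print(handle_bash_commands(response, FakePrompt()).get())
```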
41 changes: 24 additions & 17 deletions agent/prototyper.py
@@ -179,11 +179,13 @@ def _validate_fuzz_target_and_build_script_via_compile(
         status=compile_succeed and binary_exists,
         referenced=function_referenced)
 
-  def _container_handle_conclusion(
-      self, cur_round: int, response: str,
-      build_result: BuildResult) -> Optional[Prompt]:
+  def _container_handle_conclusion(self, cur_round: int, response: str,
+                                   build_result: BuildResult,
+                                   prompt: Prompt) -> Optional[Prompt]:
     """Runs a compilation tool to validate the new fuzz target and build script
     from LLM."""
+    if not self._parse_tag(response, 'fuzz target'):
+      return prompt
     logger.info('----- ROUND %02d Received conclusion -----',
                 cur_round,
                 trial=build_result.trial)
@@ -198,7 +200,8 @@ def _container_handle_conclusion(self, cur_round: int, response: str,
       return None
 
     if not build_result.compiles:
-      compile_log = self.llm.truncate_prompt(build_result.compile_log)
+      compile_log = self.llm.truncate_prompt(build_result.compile_log,
+                                             extra_text=prompt.get()).strip()
       logger.info('***** Failed to recompile in %02d rounds *****',
                   cur_round,
                   trial=build_result.trial)
@@ -243,25 +246,29 @@ def _container_handle_conclusion(self, cur_round: int, response: str,
     else:
       prompt_text = ''
 
-    prompt = DefaultTemplateBuilder(self.llm, initial=prompt_text).build([])
+    prompt.append(prompt_text)
     return prompt
 
   def _container_tool_reaction(self, cur_round: int, response: str,
                                build_result: BuildResult) -> Optional[Prompt]:
     """Validates LLM conclusion or executes its command."""
-    # Prioritize Bash instructions.
-    if command := self._parse_tag(response, 'bash'):
-      return self._container_handle_bash_command(command, self.inspect_tool)
+    prompt = DefaultTemplateBuilder(self.llm, None).build([])
+    prompt = self._container_handle_bash_commands(response, self.inspect_tool,
+                                                  prompt)
 
-    if self._parse_tag(response, 'conclusion'):
-      return self._container_handle_conclusion(cur_round, response,
-                                               build_result)
-    # Other responses are invalid.
-    logger.warning('ROUND %02d Invalid response from LLM: %s',
-                   cur_round,
-                   response,
-                   trial=build_result.trial)
-    return self._container_handle_invalid_tool_usage(self.inspect_tool)
+    # Then build fuzz target.
+    prompt = self._container_handle_conclusion(cur_round, response,
+                                               build_result, prompt)
+    if prompt is None:
+      # Succeeded.
+      return None
+
+    # Finally check invalid responses.
+    if not prompt.get():
+      prompt = self._container_handle_invalid_tool_usage(
+          self.inspect_tool, cur_round, response, prompt)
+
+    return prompt
 
   def execute(self, result_history: list[Result]) -> BuildResult:
     """Executes the agent based on previous result."""
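
In prose: the reworked `_container_tool_reaction` no longer returns after the first `<bash>` command. It runs all bash commands, then attempts the concluded fuzz target, and falls back to the re-teaching prompt only when neither step produced feedback. A simplified, self-contained sketch of that ordering (the helper names and placeholder strings are hypothetical):

```python
import re

def parse_tag(response: str, tag: str) -> str:
  match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
  return match.group(1).strip() if match else ''

def parse_tags(response: str, tag: str) -> list[str]:
  return [
      c.strip()
      for c in re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
  ]

def tool_reaction(response: str) -> str | None:
  """Returns the next prompt text, or None once the build succeeds."""
  prompt_parts: list[str] = []

  # 1. Run every <bash> command and collect the formatted results.
  for command in parse_tags(response, 'bash'):
    prompt_parts.append(f'(formatted result of: {command})')

  # 2. Then validate a concluded fuzz target; pretend compilation failed here,
  #    so the error log becomes part of the next prompt.
  if parse_tag(response, 'fuzz target'):
    prompt_parts.append('(compile error log)')

  # 3. Only when neither step produced feedback, re-teach the protocol.
  if not prompt_parts:
    prompt_parts.append('No valid instruction received, follow the protocol.')
  return '\n'.join(prompt_parts)

print(tool_reaction('<bash>ls /src</bash>\n<bash>cat build.sh</bash>'))
```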
89 changes: 89 additions & 0 deletions benchmark-sets/comparison/libdwarf.yaml
@@ -0,0 +1,89 @@
"functions":
- "name": "dwarf_get_ranges_baseaddress"
"params":
- "name": "dw_dbg"
"type": "bool "
- "name": "dw_die"
"type": "bool "
- "name": "dw_known_base"
"type": "bool "
- "name": "dw_baseaddress"
"type": "bool "
- "name": "dw_at_ranges_offset_present"
"type": "bool "
- "name": "dw_at_ranges_offset"
"type": "bool "
- "name": "dw_error"
"type": "bool "
"return_type": "int"
"signature": "int dwarf_get_ranges_baseaddress(Dwarf_Debug, Dwarf_Die, Dwarf_Bool *, Dwarf_Unsigned *, Dwarf_Bool *, Dwarf_Unsigned *, Dwarf_Error *)"
- "name": "dwarf_init_path_a"
"params":
- "name": "path"
"type": "bool "
- "name": "true_path_out_buffer"
"type": "bool "
- "name": "true_path_bufferlen"
"type": "int"
- "name": "groupnumber"
"type": "int"
- "name": "universalnumber"
"type": "int"
- "name": "errhand"
"type": "bool "
- "name": "errarg"
"type": "bool "
- "name": "ret_dbg"
"type": "bool "
- "name": "error"
"type": "bool "
"return_type": "int"
"signature": "int dwarf_init_path_a(const char *, char *, unsigned int, unsigned int, unsigned int, Dwarf_Handler, Dwarf_Ptr, Dwarf_Debug *, Dwarf_Error *)"
- "name": "dwarf_debug_addr_index_to_addr"
"params":
- "name": "die"
"type": "bool "
- "name": "index"
"type": "size_t"
- "name": "return_addr"
"type": "bool "
- "name": "error"
"type": "bool "
"return_type": "int"
"signature": "int dwarf_debug_addr_index_to_addr(Dwarf_Die, Dwarf_Unsigned, Dwarf_Addr *, Dwarf_Error *)"
- "name": "dwarf_rnglists_get_rle_head"
"params":
- "name": "attr"
"type": "bool "
- "name": "theform"
"type": "short"
- "name": "attr_val"
"type": "size_t"
- "name": "head_out"
"type": "bool "
- "name": "entries_count_out"
"type": "bool "
- "name": "global_offset_of_rle_set"
"type": "bool "
- "name": "error"
"type": "bool "
"return_type": "int"
"signature": "int dwarf_rnglists_get_rle_head(Dwarf_Attribute, Dwarf_Half, Dwarf_Unsigned, Dwarf_Rnglists_Head *, Dwarf_Unsigned *, Dwarf_Unsigned *, Dwarf_Error *)"
- "name": "dwarf_find_die_given_sig8"
"params":
- "name": "dbg"
"type": "bool "
- "name": "ref"
"type": "bool "
- "name": "die_out"
"type": "bool "
- "name": "is_info"
"type": "bool "
- "name": "error"
"type": "bool "
"return_type": "int"
"signature": "int dwarf_find_die_given_sig8(Dwarf_Debug, Dwarf_Sig8 *, Dwarf_Die *, Dwarf_Bool *, Dwarf_Error *)"
"language": "c"
"project": "libdwarf"
"target_name": "fuzz_crc"
"target_path": "/src/libdwarf/fuzz/fuzz_crc.c"
24 changes: 15 additions & 9 deletions llm_toolkit/models.py
@@ -31,9 +31,9 @@
 import openai
 import tiktoken
 import vertexai
-from google.api_core.exceptions import (GoogleAPICallError, InvalidArgument,
-                                        ResourceExhausted, ServiceUnavailable,
-                                        TooManyRequests)
+from google.api_core.exceptions import (GoogleAPICallError, InternalServerError,
+                                        InvalidArgument, ResourceExhausted,
+                                        ServiceUnavailable, TooManyRequests)
 from vertexai import generative_models
 from vertexai.preview.generative_models import ChatSession, GenerativeModel
 from vertexai.preview.language_models import CodeGenerationModel
@@ -649,6 +649,7 @@ def get_chat_client(self, model: GenerativeModel) -> Any:
           InvalidArgument,
           ValueError,  # TODO(dongge): Handle RECITATION specifically.
           IndexError,  # A known error from vertexai.
+          InternalServerError,
       ],
       other_exceptions={
           ResourceExhausted: 100,
@@ -677,17 +678,22 @@ def truncate_prompt(self,
 
     extra_text_token_count = self.estimate_token_num(extra_text)
     # Reserve 10000 tokens for raw prompt wrappers.
-    max_raw_prompt_token_size = (self.MAX_INPUT_TOKEN - extra_text_token_count -
-                                 10000)
-
+    max_raw_prompt_token_size = (self.MAX_INPUT_TOKEN * 0.9 -
+                                 extra_text_token_count) // 4
     while token_count > max_raw_prompt_token_size:
       estimate_truncate_size = int(
           (1 - max_raw_prompt_token_size / token_count) * len(raw_prompt_text))
-      raw_prompt_text = raw_prompt_text[estimate_truncate_size + 1:]
 
+      num_init_tokens = min(100, int(max_raw_prompt_token_size * 0.1))
+      raw_prompt_init = raw_prompt_text[:num_init_tokens] + (
+          '\n...(truncated due to exceeding input token limit)...\n')
+      raw_prompt_text = raw_prompt_init + raw_prompt_text[
+          min(num_init_tokens + estimate_truncate_size + 1,
+              len(raw_prompt_text) - 100):]
+
       token_count = self.estimate_token_num(raw_prompt_text)
-      logger.warning('Truncated raw prompt from %d to %d tokens:',
-                     original_token_count, token_count)
+      logger.warning('Truncated %d raw prompt chars from %d to %d tokens.',
+                     estimate_truncate_size, original_token_count, token_count)
 
     return raw_prompt_text

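The new truncation strategy keeps a short head and the tail of an overlong prompt and cuts the middle, rather than dropping the beginning wholesale. A simplified, character-level sketch (token counting replaced by character counting, which is an assumption for illustration only):

```python
def truncate_keep_head_and_tail(text: str, max_chars: int) -> str:
  """Character-based sketch of the new strategy: keep a short prefix and the
  tail, dropping the middle."""
  if len(text) <= max_chars:
    return text
  marker = '\n...(truncated due to exceeding input token limit)...\n'
  head_len = min(100, max_chars // 10)
  excess = len(text) - max_chars
  # Remove the excess from just after the head so the tail survives intact.
  return text[:head_len] + marker + text[head_len + excess:]

long_log = 'HEAD ' + 'x' * 10_000 + ' TAIL'
short = truncate_keep_head_and_tail(long_log, 500)
print(short[:10], '...', short[-10:])  # both ends survive
```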
4 changes: 2 additions & 2 deletions utils.py
@@ -35,9 +35,9 @@ def deserialize_from_dill(dill_path: Any) -> Any:
 
 
 def _default_retry_delay_fn(e: Exception, n: int):
-  """Delays retry by a random seconds between 0 to 1 minute."""
+  """Delays retry by a random seconds between 1 to 2 minutes."""
   del e, n
-  return random.uniform(0, 60)
+  return random.uniform(60, 120)
 
 
 def retryable(exceptions=None,
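
Together with the longer 60-120 s delay above, retryable callers back off on transient Vertex AI errors such as the newly added `InternalServerError`. A minimal sketch of that pattern (this is not the repo's `retryable` decorator, whose full signature is not shown in this diff):

```python
import random
import time

def retry(func, exceptions, attempts=5):
  """Calls func(), retrying on the given exceptions with a random 60-120 s
  delay between attempts, mirroring _default_retry_delay_fn."""
  for n in range(1, attempts + 1):
    try:
      return func()
    except exceptions as e:
      if n == attempts:
        raise
      delay = random.uniform(60, 120)
      print(f'Attempt {n} failed with {e!r}; retrying in {delay:.0f}s')
      time.sleep(delay)
```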