diff --git a/agent/base_agent.py b/agent/base_agent.py
index 97325be35..12d593b8a 100644
--- a/agent/base_agent.py
+++ b/agent/base_agent.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""The abstract base class for LLM agents in stages."""
import argparse
+import asyncio
import os
import random
import re
@@ -23,11 +24,14 @@
from typing import Any, Optional
import requests
+from google.adk import agents, runners, sessions
+from google.genai import errors, types
import logger
import utils
from data_prep import introspector
-from llm_toolkit.models import LLM
+from experiment import benchmark as benchmarklib
+from llm_toolkit.models import LLM, VertexAIModel
from llm_toolkit.prompts import Prompt
from results import Result
from tool.base_tool import BaseTool
@@ -295,6 +299,107 @@ def execute(self, result_history: list[Result]) -> Result:
"""Executes the agent based on previous result."""
+class ADKBaseAgent(BaseAgent):
+ """The abstract base class for agents created using the ADK library."""
+
+ def __init__(self,
+ trial: int,
+ llm: LLM,
+ args: argparse.Namespace,
+ benchmark: benchmarklib.Benchmark,
+ description: str = '',
+ instruction: str = '',
+ tools: Optional[list] = None,
+ name: str = ''):
+
+ super().__init__(trial, llm, args, tools, name)
+
+ self.benchmark = benchmark
+
+    # For now, ADKBaseAgent only supports Vertex AI models.
+ if not isinstance(llm, VertexAIModel):
+ raise ValueError(f'{self.name} only supports Vertex AI models.')
+
+ # Create the agent using the ADK library
+ adk_agent = agents.LlmAgent(
+ name=self.name,
+ model=llm._vertex_ai_model,
+ description=description,
+ instruction=instruction,
+ tools=tools or [],
+ )
+
+ # Create the session service
+ session_service = sessions.InMemorySessionService()
+ session_service.create_session(
+ app_name=self.name,
+ user_id=benchmark.id,
+ session_id=f'session_{self.trial}',
+ )
+
+ # Create the runner
+ self.runner = runners.Runner(
+ agent=adk_agent,
+ app_name=self.name,
+ session_service=session_service,
+ )
+
+ self.round = 0
+
+ logger.info('ADK Agent %s created.', self.name, trial=self.trial)
+
+ def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
+ trial: int) -> str:
+ """Call the agent with the given prompt, running async code in sync."""
+
+ self.round = cur_round
+
+ self.log_llm_prompt(prompt.get())
+
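+    # Wrap the ADK call in a coroutine so the async runner can be driven synchronously below.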
+ async def _call():
+ user_id = self.benchmark.id
+ session_id = f"session_{self.trial}"
+ content = types.Content(role='user',
+ parts=[types.Part(text=prompt.get())])
+
+ final_response_text = ''
+
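+      # Stream events from the ADK runner and keep the text of the final response.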
+ async for event in self.runner.run_async(
+ user_id=user_id,
+ session_id=session_id,
+ new_message=content,
+ ):
+ if event.is_final_response():
+ if (event.content and event.content.parts and
+ event.content.parts[0].text):
+ final_response_text = event.content.parts[0].text
+ elif event.actions and event.actions.escalate:
+ error_message = event.error_message
+ logger.error('Agent escalated: %s', error_message, trial=self.trial)
+
+ self.log_llm_response(final_response_text)
+
+ return final_response_text
+
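+    # Retry the whole call when the Vertex AI client raises a ClientError.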
+ return self.llm.with_retry_on_error(lambda: asyncio.run(_call()),
+ [errors.ClientError])
+
+  def log_llm_prompt(self, prompt: str) -> None:
+    self.round += 1
+    logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
+                self.round,
+                prompt,
+                self.round,
+                trial=self.trial)
+
+  def log_llm_response(self, response: str) -> None:
+    logger.info('<CHAT RESPONSE:ROUND %02d>%s</CHAT RESPONSE:ROUND %02d>',
+                self.round,
+                response,
+                self.round,
+                trial=self.trial)
+
+
if __name__ == "__main__":
# For cloud experiments.
BaseAgent.cloud_main()
diff --git a/agent/function_analyzer.py b/agent/function_analyzer.py
index fbed7310b..ef538b954 100644
--- a/agent/function_analyzer.py
+++ b/agent/function_analyzer.py
@@ -18,23 +18,20 @@
"""
import argparse
-import asyncio
import os
from typing import Optional
-from google.adk import agents, runners, sessions
-from google.genai import types
-
import logger
import results as resultslib
from agent import base_agent
+from data_prep import introspector
from experiment import benchmark as benchmarklib
from experiment.workdir import WorkDirs
from llm_toolkit import models, prompt_builder, prompts
-from tool import base_tool, fuzz_introspector_tool
+from tool import container_tool
-class FunctionAnalyzer(base_agent.BaseAgent):
+class FunctionAnalyzer(base_agent.ADKBaseAgent):
"""An LLM agent to analyze a function and identify its implicit requirements.
The results of this analysis will be used by the writer agents to
generate correct fuzz target for the function.
@@ -44,117 +41,52 @@ def __init__(self,
trial: int,
llm: models.LLM,
args: argparse.Namespace,
- benchmark: benchmarklib.Benchmark,
- tools: Optional[list[base_tool.BaseTool]] = None,
+               benchmark: benchmarklib.Benchmark,
name: str = ''):
- # Ensure the llm is an instance of VertexAIModel
- # TODO (pamusuo): Provide support for other LLM models
- if not isinstance(llm, models.VertexAIModel):
- raise ValueError(
- "FunctionAnalyzer agent requires a VertexAIModel instance for llm.")
-
- super().__init__(trial, llm, args, tools, name)
-
- self.vertex_ai_model = llm._vertex_ai_model
- self.benchmark = benchmark
-
- self.initialize()
-
- def initialize(self):
- """Initialize the function analyzer agent with the given benchmark."""
-
- # Initialize the Fuzz Introspector tool
- introspector_tool = fuzz_introspector_tool.FuzzIntrospectorTool(
- self.benchmark, self.name)
-
- # Create the agent using the ADK library
- # TODO(pamusuo): Create another AdkBaseAgent that extends
- # BaseAgent and initializes an ADK agent as well.
- function_analyzer = agents.LlmAgent(
- name="FunctionAnalyzer",
- model=self.vertex_ai_model,
- description="""Extracts a function's requirements
- from its source implementation.""",
- instruction=
- """You are a security engineer tasked with analyzing a function
- and extracting its input requirements,
- necessary for it to execute correctly.""",
- tools=[introspector_tool.function_source_with_name],
- )
-
- # Create the session service
- session_service = sessions.InMemorySessionService()
- session_service.create_session(
- app_name=self.name,
- user_id=self.benchmark.id,
- session_id=f"session_{self.trial}",
- )
-
- # Create the runner
- self.runner = runners.Runner(
- agent=function_analyzer,
- app_name=self.name,
- session_service=session_service,
- )
-
- logger.info("Function Analyzer Agent created, with name: %s",
- self.name,
- trial=self.trial)
-
- async def call_agent(self, query: str, runner: runners.Runner, user_id: str,
- session_id: str) -> str:
- """Call the agent asynchronously with the given query."""
-
- content = types.Content(role='user', parts=[types.Part(text=query)])
-
- final_response_text = ''
-
- result_available = False
+ builder = prompt_builder.FunctionAnalyzerTemplateBuilder(llm, benchmark)
- async for event in runner.run_async(
- user_id=user_id,
- session_id=session_id,
- new_message=content,
- ):
+ description = builder.get_description().get()
- if event.is_final_response():
- if (event.content and event.content.parts and
- event.content.parts[0].text):
- final_response_text = event.content.parts[0].text
- result_available = True
- elif event.actions and event.actions.escalate:
- error_message = event.error_message
- logger.error("Agent escalated: %s", error_message, trial=self.trial)
+ instruction = builder.get_instruction().get()
- logger.info("<<< Agent response: %s", final_response_text, trial=self.trial)
+    tools = [self.get_function_implementation, self.search_project_files]
- if result_available and self._parse_tag(final_response_text, 'response'):
- # Get the requirements from the response
- result_str = self._parse_tag(final_response_text, 'response')
- else:
- result_str = ''
+ super().__init__(trial, llm, args, benchmark, description, instruction,
+ tools, name)
- return result_str
+ self.project_functions = None
def write_requirements_to_file(self, args, requirements: str) -> str:
"""Write the requirements to a file."""
if not requirements:
- logger.warning("No requirements to write to file.", trial=self.trial)
+ logger.warning('No requirements to write to file.', trial=self.trial)
return ''
requirement_path = os.path.join(args.work_dirs.requirements,
- f"{self.benchmark.id}.txt")
+ f'{self.benchmark.id}.txt')
with open(requirement_path, 'w') as f:
f.write(requirements)
- logger.info("Requirements written to %s",
+ logger.info('Requirements written to %s',
requirement_path,
trial=self.trial)
return requirement_path
+ def handle_llm_response(self, final_response_text: str,
+ result: resultslib.Result) -> None:
+ """Handle the LLM response and update the result."""
+
+ result_str = self._parse_tag(final_response_text, 'response')
+ requirements = self._parse_tag(result_str, 'requirements')
+ if requirements:
+ # Write the requirements to a file
+ requirement_path = self.write_requirements_to_file(self.args, result_str)
+ function_analysis = resultslib.FunctionAnalysisResult(requirement_path)
+ result.function_analysis = function_analysis
+
def execute(self,
result_history: list[resultslib.Result]) -> resultslib.Result:
"""Execute the agent with the given results."""
@@ -168,29 +100,21 @@ def execute(self,
work_dirs=self.args.work_dirs,
)
+ # Initialize the ProjectContainerTool for local file search
+ self.inspect_tool = container_tool.ProjectContainerTool(self.benchmark)
+ self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null')
+
# Call the agent asynchronously and return the result
prompt = self._initial_prompt(result_history)
- query = prompt.gettext()
- # Validate query is not empty
- if not query.strip():
- logger.error(
- "Error occurred while building initial prompt. Cannot call the agent.",
- trial=self.trial)
- return result
-
- logger.info("Initial prompt created. Calling LLM...", trial=self.trial)
+ final_response_text = self.chat_llm(self.round,
+ client=None,
+ prompt=prompt,
+ trial=result_history[-1].trial)
- user_id = self.benchmark.id
- session_id = f"session_{self.trial}"
- result_str = asyncio.run(
- self.call_agent(query, self.runner, user_id, session_id))
+ self.handle_llm_response(final_response_text, result)
- if result_str:
- # Write the requirements to a file
- requirement_path = self.write_requirements_to_file(self.args, result_str)
- function_analysis = resultslib.FunctionAnalysisResult(requirement_path)
- result.function_analysis = function_analysis
+ self.inspect_tool.terminate()
return result
@@ -205,4 +129,105 @@ def _initial_prompt(
prompt = builder.build_prompt()
+ prompt.append(self.inspect_tool.tutorial())
+
return prompt
+
+ def search_project_files(self, request: str) -> str:
+ """
+ This function tool uses bash commands to search the project's source files,
+ and retrieve requested code snippets or file contents.
+ Args:
+ request (str): The bash command to execute and its justification,
+        formatted using the <reason> and <bash> tags.
+ Returns:
+ str: The response from executing the bash commands,
+        formatted using the <bash>, <stdout> and <stderr> tags.
+ """
+
+ self.log_llm_response(request)
+
+ prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([])
+
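+    # Execute the requested bash commands inside the project container and capture their output.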
+ if request:
+ prompt = self._container_handle_bash_commands(request, self.inspect_tool,
+ prompt)
+
+    # Finally, handle an invalid or empty request.
+ if not request or not prompt.get():
+ prompt = self._container_handle_invalid_tool_usage(
+ self.inspect_tool, 0, request, prompt)
+
+ tool_response = prompt.get()
+
+ self.log_llm_prompt(tool_response)
+
+ return tool_response
+
+ def get_function_implementation(self, project_name: str,
+ function_name: str) -> str:
+ """
+    Retrieves a function's source from the Fuzz Introspector API,
+    using the project's name and function's name.
+
+ Args:
+ project_name (str): The name of the project.
+ function_name (str): The name of the function.
+
+ Returns:
+ str: Source code of the function if found, otherwise an empty string.
+ """
+ request = f"""
+ Requesting implementation for the function:
+ Function name: {function_name}
+ Project name: {project_name}
+ """
+
+ self.log_llm_response(request)
+
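+    # Lazily build a map from function name to its metadata the first time this tool is called.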
+ if self.project_functions is None:
+ logger.info(
+ 'Project functions not initialized. Initializing for project "%s".',
+ project_name,
+ trial=self.trial)
+ functions_list = introspector.query_introspector_all_functions(
+ project_name)
+
+ if functions_list:
+ self.project_functions = {
+ func['debug_summary']['name']: func
+ for func in functions_list
+ if isinstance(func.get('debug_summary'), dict) and
+ isinstance(func['debug_summary'].get('name'), str) and
+ func['debug_summary']['name'].strip()
+ }
+ else:
+ self.project_functions = None
+
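+    # Look up the function's signature and fetch its source code from Fuzz Introspector.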
+ response = f"""
+ Project name: {project_name}
+ Function name: {function_name}
+ """
+ function_source = ''
+
+ if self.project_functions:
+ function_dict = self.project_functions.get(function_name, {})
+ function_signature = function_dict.get('function_signature', '')
+
+ function_source = introspector.query_introspector_function_source(
+ project_name, function_signature)
+
+ if function_source.strip():
+ response += f"""
+ Function source code:
+ {function_source}
+ """
+ else:
+ logger.error('Error: Function with name "%s" not found in project "%s".',
+ function_name,
+ project_name,
+ trial=self.trial)
+
+ self.log_llm_prompt(response)
+
+ return response
diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py
index 2fed2d76a..b9fcf5d40 100644
--- a/llm_toolkit/prompt_builder.py
+++ b/llm_toolkit/prompt_builder.py
@@ -901,36 +901,34 @@ def __init__(self,
super().__init__(model, benchmark, template_dir, initial)
# Load templates.
- self.function_analyzer_instruction_template_file = self._find_template(
+ self.function_analyzer_instruction_file = self._find_template(
AGENT_TEMPLATE_DIR, 'function-analyzer-instruction.txt')
- self.context_retrieve_template_file = self._find_template(
- AGENT_TEMPLATE_DIR, 'context-retriever-instruction.txt')
+ self.function_analyzer_description_file = self._find_template(
+ AGENT_TEMPLATE_DIR, 'function-analyzer-description.txt')
self.function_analyzer_prompt_template_file = self._find_template(
AGENT_TEMPLATE_DIR, 'function-analyzer-priming.txt')
- def build_instruction(self) -> prompts.Prompt:
+ def get_instruction(self) -> prompts.Prompt:
"""Constructs a prompt using the templates in |self| and saves it."""
self._prompt = self._model.prompt_type()(None)
if not self.benchmark:
return self._prompt
- prompt = self._get_template(
- self.function_analyzer_instruction_template_file)
+ prompt = self._get_template(self.function_analyzer_instruction_file)
self._prompt.append(prompt)
return self._prompt
- def build_context_retriever_instruction(self) -> prompts.Prompt:
+ def get_description(self) -> prompts.Prompt:
"""Constructs a prompt using the templates in |self| and saves it."""
self._prompt = self._model.prompt_type()(None)
-
if not self.benchmark:
return self._prompt
- prompt = self._get_template(self.context_retrieve_template_file)
+ prompt = self._get_template(self.function_analyzer_description_file)
self._prompt.append(prompt)
@@ -957,7 +955,6 @@ def build_prompt(self) -> prompts.Prompt:
if not func_source:
logger.error('No function source found for project: %s, function: %s',
self.benchmark.project, self.benchmark.function_signature)
- return prompts.TextPrompt()
prompt = prompt.replace('{FUNCTION_SOURCE}', func_source)
@@ -967,9 +964,9 @@ def build_prompt(self) -> prompts.Prompt:
if not xrefs:
logger.error('No cross references found for project: %s, function: %s',
self.benchmark.project, self.benchmark.function_signature)
-      prompt = prompt.replace(
-        '<function-references>\n{FUNCTION_REFERENCES}\n</function-references>',
-        '')
+      prompt = prompt.replace('<function-references>', '')\
+          .replace('{FUNCTION_REFERENCES}', '')\
+          .replace('</function-references>', '')
else:
references = [f"\n{xref}\n" for xref in xrefs]
references_str = '\n'.join(references)
diff --git a/prompts/agent/function-analyzer-description.txt b/prompts/agent/function-analyzer-description.txt
new file mode 100644
index 000000000..484afc037
--- /dev/null
+++ b/prompts/agent/function-analyzer-description.txt
@@ -0,0 +1 @@
+Extracts a function's requirements from its source implementation.
\ No newline at end of file
diff --git a/prompts/agent/function-analyzer-instruction.txt b/prompts/agent/function-analyzer-instruction.txt
index e1dc112b9..f6831d050 100644
--- a/prompts/agent/function-analyzer-instruction.txt
+++ b/prompts/agent/function-analyzer-instruction.txt
@@ -1,80 +1,3 @@
-You are a professional security engineer.
-
-Your objective is to analyze the function's implementation using the steps provided and return a response in the expected format.
-The requirements you provide will be used by another agent to generate valid fuzz drivers for the target function.
-
-The function you will analyze is provided below. We have provided the target function, and the implementations of its children functions.
-
-
-{{FUNCTION_SOURCE}}
-
-
-Follow these steps to analyze a function and identify its input requirements:
-
-Step 1: Identify all Fuzzing Crash Indicators (FCI) in the function.
- * Fuzz Crash Indicators are statements that can cause the program to crash if expected conditions are violated.
- * They include assertion statements, array indexing statements, pointer dereferencing statements, memory access statements, string handling statements, etc.
- * Note that some programs can have custom assertion statements, like require() or ensure().
-
-Step 2: Identify the input requirements necessary to ensure the safety of each identified Fuzzing Crash Indicators.
- * Each requirement MUST be precise for it to be useful.
- * You MUST include a one-sentence summary why a specific requirement was included.
- * You should not repeat any requirement, even if it is necessary to satisfy multiple FCIs.
-
-Step 3: Compile the requirements you derived and return in the expected format.
-
-
-
-
-Make sure your response follows the following format, enclosed in ``` ```.
-
-```
-
-
-project name: the name of the project provided
-function signature: The function's signature
-
-
-
-A summary of what the function does.
-
-
-
-
-First requirement
-
-
-Second requirement
-
-...
-
-nth requirement
-
-
-
-
-
-
-
-
-Here is an example response
-
-
-project name: htslib
-function signature: int sam_index_build(const char *, int)
-
-
-
-The sam_index_build function is used to build a sam index. It uses the input arguments to identify and retrieve the index to build. It returns 1 if the build succeeds and 0 if the build fails.
-
-
-
-
-The second argument should be less than 64. This is to prevent an assertion violation in the program.
-
-
-
-
-
-
-
+You are a security engineer tasked with analyzing a function
+and extracting the input requirements necessary for it to execute correctly.
\ No newline at end of file
diff --git a/prompts/agent/function-analyzer-priming.txt b/prompts/agent/function-analyzer-priming.txt
index 6a289e8b1..86d3baa76 100644
--- a/prompts/agent/function-analyzer-priming.txt
+++ b/prompts/agent/function-analyzer-priming.txt
@@ -1,17 +1,30 @@
-You are a professional security engineer working on creating a valid fuzzing driver for the target function `{FUNCTION_SIGNATURE}` in the project {PROJECT_NAME}.
-We will provide you with the implementation of the target function, implementations of other functions that reference the target function, and a set of tools that you can use to get additional function implementations and context information.
-
-Your goal is to analyze the provided functions and its usages, provide a clear detailed description of the function, and identify the important input requirements for the target function to execute correctly.
-
-The requirements we are interested in include the following:
-5. WHat constraints on input arguments is necessary to prevent assertion failures, out-of-bound array indexing, null pointer dereferencing, invalid memory access, invalid string access, and other crashes.
-1. What setup functions must be called before the target function?
-2. What existing function in the project should we use to create valid inputs for the target function?
-3. What inputs, or members of an input, should we initialize with random fuzz data?
-4. What inputs must we initialize by calling another existing function?
-
-Keep your responses concise. Each requirement should contain two sentences. The first is the requirement. The second is a brief reason why it is important.
+
+You are a professional security engineer identifying the input requirements for the target function `{FUNCTION_SIGNATURE}` in the project {PROJECT_NAME}.
+We will provide you with the implementation of the target function, implementations of functions that reference the target function, and a set of tools that you can use to get additional context information about the target function.
+Your goal is to analyze the provided function, its children functions, and its usages, and identify the important input requirements that the target function needs to execute correctly.
+
+
+
+ We are interested in only the following kinds of requirements.
+ - Input requirements that are necessary to prevent program crashes.
+    * Program crashes can be caused by assertion failures, invalid array indexing, out-of-bound memory accesses, and null pointer dereferences.
+ - Requirements for creating valid input arguments.
+ * Here, you should mention what existing function or functions should be used to create a valid input argument.
+    * For example, if a function takes in an integer argument but uses that argument as a file descriptor for reading a file (e.g. the read function), then it implies the integer must have been returned by another function that creates a file descriptor (e.g. the open function).
+    * Similarly, if a function takes in a character pointer and uses it as a file path or name, then this implies a valid file should be created and its path or name passed to this function.
+    * Also, if a function takes in a pointer argument and uses that argument as an argument to strlen, strcpy, or another string handling function, this implies the function expects a null-terminated string.
+  - Relationships between inputs.
+ * For example, this can be the relationship between a pointer and an integer argument representing its size.
+ - Input variables that should be fuzzed
+ * What input variables can be user-controlled or contain invalid values?
+ * For example, if a function parses or processes one of its input arguments, then that argument is fuzzable.
+ - Setup functions to call before the target function can be called.
+    * This is the function or set of functions we must call before calling the target function.
+ * For example, if a function depends on a global variable which is set by another function, this may imply we need to call that function before the target function.
+
+Keep each requirement concise. Each requirement should contain two sentences. The first is the requirement. The second is a brief reason why it is important.
+
Here is the provided data.
@@ -27,9 +40,8 @@ Here is the provided data.
You MUST return your response in the format below.
-Make sure your response follows the following format, enclosed in ``` ```.
+Make sure your response follows the following format.
-```
project name: the name of the project provided
@@ -84,10 +96,11 @@ The third argument should be greater than zero. This is to prevent an assertion
The third argument should be less than 16. This is to prevent an out-of-bound array access when the argument is used to index the fixed-size array `stores`.
-
You will be provided with the following tools.
-1. _function_source_with_name: Use this tool to retrieve the implementation of a function. You will invoke the tool using the project's name and function's name as arguments.
\ No newline at end of file
+1. get_function_implementation: This is a tool you can use to retrieve the implementation of a function using the project's name and function's name as arguments.
+2. search_project_files: This is an interactive tool you can use to search the project's source files using bash commands and find definitions or usages of functions, classes, structs, and variables.
+ The usage guide for the Bash Tool is provided below.
diff --git a/stage/writing_stage.py b/stage/writing_stage.py
index 7ae28836e..805abc6d6 100644
--- a/stage/writing_stage.py
+++ b/stage/writing_stage.py
@@ -58,6 +58,7 @@ def execute(self, result_history: list[Result]) -> Result:
agent = self.get_agent(index=0)
if agent.name == 'FunctionAnalyzer':
agent_result = self._execute_agent(agent, result_history)
+ self.logger.write_chat_history(agent_result)
result_history.append(agent_result)
# Then, execute the Prototyper agent to refine the fuzz target.
diff --git a/tool/fuzz_introspector_tool.py b/tool/fuzz_introspector_tool.py
index a56cea09c..5be8e14ec 100644
--- a/tool/fuzz_introspector_tool.py
+++ b/tool/fuzz_introspector_tool.py
@@ -79,14 +79,11 @@ def function_source_with_signature(self, project_name: str,
return function_code
- def function_source_with_name(self, project_name: str,
- function_name: str) -> str:
+ def get_function_implementation(self, project_name: str,
+ function_name: str) -> str:
"""
Retrieves a function's source from the fuzz introspector API,
- using the project's name and function's name.
- This function first retrieves the list of all
- functions in the project, so it can get the function's signature.
- Then it uses the function's signature to retrieve the source code.
+    using the project's name and function's name.
Args:
project_name (str): The name of the project.