diff --git a/agent/base_agent.py b/agent/base_agent.py
index 97325be35..12d593b8a 100644
--- a/agent/base_agent.py
+++ b/agent/base_agent.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """The abstract base class for LLM agents in stages."""
 import argparse
+import asyncio
 import os
 import random
 import re
@@ -23,11 +24,14 @@ from typing import Any, Optional
 
 import requests
+from google.adk import agents, runners, sessions
+from google.genai import errors, types
 
 import logger
 import utils
 from data_prep import introspector
-from llm_toolkit.models import LLM
+from experiment import benchmark as benchmarklib
+from llm_toolkit.models import LLM, VertexAIModel
 from llm_toolkit.prompts import Prompt
 from results import Result
 from tool.base_tool import BaseTool
@@ -295,6 +299,107 @@ def execute(self, result_history: list[Result]) -> Result:
     """Executes the agent based on previous result."""
 
 
+class ADKBaseAgent(BaseAgent):
+  """The abstract base class for agents created using the ADK library."""
+
+  def __init__(self,
+               trial: int,
+               llm: LLM,
+               args: argparse.Namespace,
+               benchmark: benchmarklib.Benchmark,
+               description: str = '',
+               instruction: str = '',
+               tools: Optional[list] = None,
+               name: str = ''):
+
+    super().__init__(trial, llm, args, tools, name)
+
+    self.benchmark = benchmark
+
+    # For now, ADKBaseAgents only support the Vertex AI Models.
+    if not isinstance(llm, VertexAIModel):
+      raise ValueError(f'{self.name} only supports Vertex AI models.')
+
+    # Create the agent using the ADK library
+    adk_agent = agents.LlmAgent(
+        name=self.name,
+        model=llm._vertex_ai_model,
+        description=description,
+        instruction=instruction,
+        tools=tools or [],
+    )
+
+    # Create the session service
+    session_service = sessions.InMemorySessionService()
+    session_service.create_session(
+        app_name=self.name,
+        user_id=benchmark.id,
+        session_id=f'session_{self.trial}',
+    )
+
+    # Create the runner
+    self.runner = runners.Runner(
+        agent=adk_agent,
+        app_name=self.name,
+        session_service=session_service,
+    )
+
+    self.round = 0
+
+    logger.info('ADK Agent %s created.', self.name, trial=self.trial)
+
+  def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
+               trial: int) -> str:
+    """Call the agent with the given prompt, running async code in sync."""
+
+    self.round = cur_round
+
+    self.log_llm_prompt(prompt.get())
+
+    async def _call():
+      user_id = self.benchmark.id
+      session_id = f"session_{self.trial}"
+      content = types.Content(role='user',
+                              parts=[types.Part(text=prompt.get())])
+
+      final_response_text = ''
+
+      async for event in self.runner.run_async(
+          user_id=user_id,
+          session_id=session_id,
+          new_message=content,
+      ):
+        if event.is_final_response():
+          if (event.content and event.content.parts and
+              event.content.parts[0].text):
+            final_response_text = event.content.parts[0].text
+          elif event.actions and event.actions.escalate:
+            error_message = event.error_message
+            logger.error('Agent escalated: %s', error_message, trial=self.trial)
+
+      self.log_llm_response(final_response_text)
+
+      return final_response_text
+
+    return self.llm.with_retry_on_error(lambda: asyncio.run(_call()),
+                                        [errors.ClientError])
+
+  def log_llm_prompt(self, prompt: str) -> None:
+    self.round += 1
+    logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
+                self.round,
+                prompt,
+                self.round,
+                trial=self.trial)
+
+  def log_llm_response(self, response: str) -> None:
+    logger.info('<CHAT RESPONSE:ROUND %02d>%s</CHAT RESPONSE:ROUND %02d>',
+                self.round,
+                response,
+                self.round,
+                trial=self.trial)
+
+
 if __name__ == "__main__":
   # For cloud experiments.
BaseAgent.cloud_main() diff --git a/agent/function_analyzer.py b/agent/function_analyzer.py index fbed7310b..ef538b954 100644 --- a/agent/function_analyzer.py +++ b/agent/function_analyzer.py @@ -18,23 +18,20 @@ """ import argparse -import asyncio import os from typing import Optional -from google.adk import agents, runners, sessions -from google.genai import types - import logger import results as resultslib from agent import base_agent +from data_prep import introspector from experiment import benchmark as benchmarklib from experiment.workdir import WorkDirs from llm_toolkit import models, prompt_builder, prompts -from tool import base_tool, fuzz_introspector_tool +from tool import container_tool -class FunctionAnalyzer(base_agent.BaseAgent): +class FunctionAnalyzer(base_agent.ADKBaseAgent): """An LLM agent to analyze a function and identify its implicit requirements. The results of this analysis will be used by the writer agents to generate correct fuzz target for the function. @@ -44,117 +41,52 @@ def __init__(self, trial: int, llm: models.LLM, args: argparse.Namespace, - benchmark: benchmarklib.Benchmark, - tools: Optional[list[base_tool.BaseTool]] = None, + benchmark: benchmarklib.Benchmark,\ name: str = ''): - # Ensure the llm is an instance of VertexAIModel - # TODO (pamusuo): Provide support for other LLM models - if not isinstance(llm, models.VertexAIModel): - raise ValueError( - "FunctionAnalyzer agent requires a VertexAIModel instance for llm.") - - super().__init__(trial, llm, args, tools, name) - - self.vertex_ai_model = llm._vertex_ai_model - self.benchmark = benchmark - - self.initialize() - - def initialize(self): - """Initialize the function analyzer agent with the given benchmark.""" - - # Initialize the Fuzz Introspector tool - introspector_tool = fuzz_introspector_tool.FuzzIntrospectorTool( - self.benchmark, self.name) - - # Create the agent using the ADK library - # TODO(pamusuo): Create another AdkBaseAgent that extends - # BaseAgent and initializes an ADK agent as well. 
- function_analyzer = agents.LlmAgent( - name="FunctionAnalyzer", - model=self.vertex_ai_model, - description="""Extracts a function's requirements - from its source implementation.""", - instruction= - """You are a security engineer tasked with analyzing a function - and extracting its input requirements, - necessary for it to execute correctly.""", - tools=[introspector_tool.function_source_with_name], - ) - - # Create the session service - session_service = sessions.InMemorySessionService() - session_service.create_session( - app_name=self.name, - user_id=self.benchmark.id, - session_id=f"session_{self.trial}", - ) - - # Create the runner - self.runner = runners.Runner( - agent=function_analyzer, - app_name=self.name, - session_service=session_service, - ) - - logger.info("Function Analyzer Agent created, with name: %s", - self.name, - trial=self.trial) - - async def call_agent(self, query: str, runner: runners.Runner, user_id: str, - session_id: str) -> str: - """Call the agent asynchronously with the given query.""" - - content = types.Content(role='user', parts=[types.Part(text=query)]) - - final_response_text = '' - - result_available = False + builder = prompt_builder.FunctionAnalyzerTemplateBuilder(llm, benchmark) - async for event in runner.run_async( - user_id=user_id, - session_id=session_id, - new_message=content, - ): + description = builder.get_description().get() - if event.is_final_response(): - if (event.content and event.content.parts and - event.content.parts[0].text): - final_response_text = event.content.parts[0].text - result_available = True - elif event.actions and event.actions.escalate: - error_message = event.error_message - logger.error("Agent escalated: %s", error_message, trial=self.trial) + instruction = builder.get_instruction().get() - logger.info("<<< Agent response: %s", final_response_text, trial=self.trial) + tools = [self.get_function_implementation, self.search_project_files] - if result_available and self._parse_tag(final_response_text, 'response'): - # Get the requirements from the response - result_str = self._parse_tag(final_response_text, 'response') - else: - result_str = '' + super().__init__(trial, llm, args, benchmark, description, instruction, + tools, name) - return result_str + self.project_functions = None def write_requirements_to_file(self, args, requirements: str) -> str: """Write the requirements to a file.""" if not requirements: - logger.warning("No requirements to write to file.", trial=self.trial) + logger.warning('No requirements to write to file.', trial=self.trial) return '' requirement_path = os.path.join(args.work_dirs.requirements, - f"{self.benchmark.id}.txt") + f'{self.benchmark.id}.txt') with open(requirement_path, 'w') as f: f.write(requirements) - logger.info("Requirements written to %s", + logger.info('Requirements written to %s', requirement_path, trial=self.trial) return requirement_path + def handle_llm_response(self, final_response_text: str, + result: resultslib.Result) -> None: + """Handle the LLM response and update the result.""" + + result_str = self._parse_tag(final_response_text, 'response') + requirements = self._parse_tag(result_str, 'requirements') + if requirements: + # Write the requirements to a file + requirement_path = self.write_requirements_to_file(self.args, result_str) + function_analysis = resultslib.FunctionAnalysisResult(requirement_path) + result.function_analysis = function_analysis + def execute(self, result_history: list[resultslib.Result]) -> resultslib.Result: """Execute the agent with 
the given results."""
@@ -168,29 +100,21 @@ def execute(self,
         work_dirs=self.args.work_dirs,
     )
 
+    # Initialize the ProjectContainerTool for local file search
+    self.inspect_tool = container_tool.ProjectContainerTool(self.benchmark)
+    self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null')
+
     # Call the agent asynchronously and return the result
     prompt = self._initial_prompt(result_history)
-    query = prompt.gettext()
 
-    # Validate query is not empty
-    if not query.strip():
-      logger.error(
-          "Error occurred while building initial prompt. Cannot call the agent.",
-          trial=self.trial)
-      return result
-
-    logger.info("Initial prompt created. Calling LLM...", trial=self.trial)
+    final_response_text = self.chat_llm(self.round,
+                                        client=None,
+                                        prompt=prompt,
+                                        trial=result_history[-1].trial)
 
-    user_id = self.benchmark.id
-    session_id = f"session_{self.trial}"
-    result_str = asyncio.run(
-        self.call_agent(query, self.runner, user_id, session_id))
+    self.handle_llm_response(final_response_text, result)
 
-    if result_str:
-      # Write the requirements to a file
-      requirement_path = self.write_requirements_to_file(self.args, result_str)
-      function_analysis = resultslib.FunctionAnalysisResult(requirement_path)
-      result.function_analysis = function_analysis
+    self.inspect_tool.terminate()
 
     return result
 
@@ -205,4 +129,105 @@ def _initial_prompt(
 
     prompt = builder.build_prompt()
 
+    prompt.append(self.inspect_tool.tutorial())
+
     return prompt
+
+  def search_project_files(self, request: str) -> str:
+    """
+    This function tool uses bash commands to search the project's source files,
+    and retrieve requested code snippets or file contents.
+    Args:
+      request (str): The bash command to execute and its justification,
+        formatted using the <reason> and <bash> tags.
+    Returns:
+      str: The response from executing the bash commands,
+        formatted using the <bash>, <stdout> and <stderr> tags.
+    """
+
+    self.log_llm_response(request)
+
+    prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([])
+
+    if request:
+      prompt = self._container_handle_bash_commands(request, self.inspect_tool,
+                                                    prompt)
+
+    # Finally, check for an invalid request.
+    if not request or not prompt.get():
+      prompt = self._container_handle_invalid_tool_usage(
+          self.inspect_tool, 0, request, prompt)
+
+    tool_response = prompt.get()
+
+    self.log_llm_prompt(tool_response)
+
+    return tool_response
+
+  def get_function_implementation(self, project_name: str,
+                                  function_name: str) -> str:
+    """
+    Retrieves a function's source from the fuzz introspector API,
+    using the project's name and function's name.
+
+    Args:
+      project_name (str): The name of the project.
+      function_name (str): The name of the function.
+
+    Returns:
+      str: Source code of the function if found, otherwise an empty string.
+    """
+    request = f"""
+    Requesting implementation for the function:
+    Function name: {function_name}
+    Project name: {project_name}
+    """
+
+    self.log_llm_response(request)
+
+    if self.project_functions is None:
+      logger.info(
+          'Project functions not initialized. 
Initializing for project "%s".', + project_name, + trial=self.trial) + functions_list = introspector.query_introspector_all_functions( + project_name) + + if functions_list: + self.project_functions = { + func['debug_summary']['name']: func + for func in functions_list + if isinstance(func.get('debug_summary'), dict) and + isinstance(func['debug_summary'].get('name'), str) and + func['debug_summary']['name'].strip() + } + else: + self.project_functions = None + + response = f""" + Project name: {project_name} + Function name: {function_name} + """ + function_source = '' + + if self.project_functions: + function_dict = self.project_functions.get(function_name, {}) + function_signature = function_dict.get('function_signature', '') + + function_source = introspector.query_introspector_function_source( + project_name, function_signature) + + if function_source.strip(): + response += f""" + Function source code: + {function_source} + """ + else: + logger.error('Error: Function with name "%s" not found in project "%s".', + function_name, + project_name, + trial=self.trial) + + self.log_llm_prompt(response) + + return response diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index 2fed2d76a..b9fcf5d40 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -901,36 +901,34 @@ def __init__(self, super().__init__(model, benchmark, template_dir, initial) # Load templates. - self.function_analyzer_instruction_template_file = self._find_template( + self.function_analyzer_instruction_file = self._find_template( AGENT_TEMPLATE_DIR, 'function-analyzer-instruction.txt') - self.context_retrieve_template_file = self._find_template( - AGENT_TEMPLATE_DIR, 'context-retriever-instruction.txt') + self.function_analyzer_description_file = self._find_template( + AGENT_TEMPLATE_DIR, 'function-analyzer-description.txt') self.function_analyzer_prompt_template_file = self._find_template( AGENT_TEMPLATE_DIR, 'function-analyzer-priming.txt') - def build_instruction(self) -> prompts.Prompt: + def get_instruction(self) -> prompts.Prompt: """Constructs a prompt using the templates in |self| and saves it.""" self._prompt = self._model.prompt_type()(None) if not self.benchmark: return self._prompt - prompt = self._get_template( - self.function_analyzer_instruction_template_file) + prompt = self._get_template(self.function_analyzer_instruction_file) self._prompt.append(prompt) return self._prompt - def build_context_retriever_instruction(self) -> prompts.Prompt: + def get_description(self) -> prompts.Prompt: """Constructs a prompt using the templates in |self| and saves it.""" self._prompt = self._model.prompt_type()(None) - if not self.benchmark: return self._prompt - prompt = self._get_template(self.context_retrieve_template_file) + prompt = self._get_template(self.function_analyzer_description_file) self._prompt.append(prompt) @@ -957,7 +955,6 @@ def build_prompt(self) -> prompts.Prompt: if not func_source: logger.error('No function source found for project: %s, function: %s', self.benchmark.project, self.benchmark.function_signature) - return prompts.TextPrompt() prompt = prompt.replace('{FUNCTION_SOURCE}', func_source) @@ -967,9 +964,9 @@ def build_prompt(self) -> prompts.Prompt: if not xrefs: logger.error('No cross references found for project: %s, function: %s', self.benchmark.project, self.benchmark.function_signature) - prompt = prompt.replace( - '\n{FUNCTION_REFERENCES}\n}', - '') + prompt = prompt.replace('', '')\ + .replace('{FUNCTION_REFERENCES}', '')\ + 
.replace('', '') else: references = [f"\n{xref}\n" for xref in xrefs] references_str = '\n'.join(references) diff --git a/prompts/agent/function-analyzer-description.txt b/prompts/agent/function-analyzer-description.txt new file mode 100644 index 000000000..484afc037 --- /dev/null +++ b/prompts/agent/function-analyzer-description.txt @@ -0,0 +1 @@ +Extracts a function's requirements from its source implementation. \ No newline at end of file diff --git a/prompts/agent/function-analyzer-instruction.txt b/prompts/agent/function-analyzer-instruction.txt index e1dc112b9..f6831d050 100644 --- a/prompts/agent/function-analyzer-instruction.txt +++ b/prompts/agent/function-analyzer-instruction.txt @@ -1,80 +1,3 @@ -You are a professional security engineer. - -Your objective is to analyze the function's implementation using the steps provided and return a response in the expected format. -The requirements you provide will be used by another agent to generate valid fuzz drivers for the target function. - -The function you will analyze is provided below. We have provided the target function, and the implementations of its children functions. - - -{{FUNCTION_SOURCE}} - - -Follow these steps to analyze a function and identify its input requirements: - -Step 1: Identify all Fuzzing Crash Indicators (FCI) in the function. - * Fuzz Crash Indicators are statements that can cause the program to crash if expected conditions are violated. - * They include assertion statements, array indexing statements, pointer dereferencing statements, memory access statements, string handling statements, etc. - * Note that some programs can have custom assertion statements, like require() or ensure(). - -Step 2: Identify the input requirements necessary to ensure the safety of each identified Fuzzing Crash Indicators. - * Each requirement MUST be precise for it to be useful. - * You MUST include a one-sentence summary why a specific requirement was included. - * You should not repeat any requirement, even if it is necessary to satisfy multiple FCIs. - -Step 3: Compile the requirements you derived and return in the expected format. - - - - -Make sure your response follows the following format, enclosed in ``` ```. - -``` - - -project name: the name of the project provided -function signature: The function's signature - - - -A summary of what the function does. - - - - -First requirement - - -Second requirement - -... - -nth requirement - - - - - - - - -Here is an example response - - -project name: htslib -function signature: int sam_index_build(const char *, int) - - - -The sam_index_build function is used to build a sam index. It uses the input arguments to identify and retrieve the index to build. It returns 1 if the build succeeds and 0 if the build fails. - - - - -The second argument should be less than 64. This is to prevent an assertion violation in the program. - - - - - - - +You are a security engineer tasked with analyzing a function + and extracting its input requirements, + necessary for it to execute correctly. \ No newline at end of file diff --git a/prompts/agent/function-analyzer-priming.txt b/prompts/agent/function-analyzer-priming.txt index 6a289e8b1..86d3baa76 100644 --- a/prompts/agent/function-analyzer-priming.txt +++ b/prompts/agent/function-analyzer-priming.txt @@ -1,17 +1,30 @@ -You are a professional security engineer working on creating a valid fuzzing driver for the target function `{FUNCTION_SIGNATURE}` in the project {PROJECT_NAME}. 
-We will provide you with the implementation of the target function, implementations of other functions that reference the target function, and a set of tools that you can use to get additional function implementations and context information.
-
-Your goal is to analyze the provided functions and its usages, provide a clear detailed description of the function, and identify the important input requirements for the target function to execute correctly.
-
-The requirements we are interested in include the following:
-5. WHat constraints on input arguments is necessary to prevent assertion failures, out-of-bound array indexing, null pointer dereferencing, invalid memory access, invalid string access, and other crashes.
-1. What setup functions must be called before the target function?
-2. What existing function in the project should we use to create valid inputs for the target function?
-3. What inputs, or members of an input, should we initialize with random fuzz data?
-4. What inputs must we initialize by calling another existing function?
-
-Keep your responses concise. Each requirement should contain two sentences. The first is the requirement. The second is a brief reason why it is important.
+
+You are a professional security engineer identifying the input requirements for the target function `{FUNCTION_SIGNATURE}` in the project {PROJECT_NAME}.
+We will provide you with the implementation of the target function, implementations of functions that reference the target function, and a set of tools that you can use to get additional context information about the target function.
+Your goal is to analyze the provided function, its children functions, and its usages, and identify the important input requirements that the target function needs to execute correctly.
+
+
+
+We are interested in only the following kinds of requirements.
+- Input requirements that are necessary to prevent program crashes.
+  * Program crashes can be caused by assertion failures, invalid array indexing, out-of-bound memory accesses, pointer dereferencing failures.
+- Requirements for creating valid input arguments.
+  * Here, you should mention what existing function or functions should be used to create a valid input argument.
+  * For example, if a function takes in an integer argument but uses that argument as a file descriptor for reading a file (e.g. the read function), then it implies the integer must have been returned by another function that creates a file descriptor (e.g. the open function).
+  * Similarly, if a function takes in a character pointer and uses it like a file path or name, then this implies a valid file should be created and the path or name passed to this function.
+  * Also, if a function takes in a pointer argument and uses that argument as an argument to strlen or strcpy or other string handling function, this implies the function expects a null-terminated string.
+- Relationship between inputs
+  * For example, this can be the relationship between a pointer and an integer argument representing its size.
+- Input variables that should be fuzzed
+  * What input variables can be user-controlled or contain invalid values?
+  * For example, if a function parses or processes one of its input arguments, then that argument is fuzzable.
+- Setup functions to call before the target function can be called.
+  * This is the function or set of functions we must call before calling the target function.
+ * For example, if a function depends on a global variable which is set by another function, this may imply we need to call that function before the target function. + +Keep each requirement concise. Each requirement should contain two sentences. The first is the requirement. The second is a brief reason why it is important. + Here is the provided data. @@ -27,9 +40,8 @@ Here is the provided data. You MUST return your response in the format below. -Make sure your response follows the following format, enclosed in ``` ```. +Make sure your response follows the following format. -``` project name: the name of the project provided @@ -84,10 +96,11 @@ The third argument should be greater than zero. This is to prevent an assertion The third argument should be less than 16. This is to prevent an out-of-bound array access when the argument is used to index the fixed-size array `stores`. - You will be provided with the following tools. -1. _function_source_with_name: Use this tool to retrieve the implementation of a function. You will invoke the tool using the project's name and function's name as arguments. \ No newline at end of file +1. get_function_implementation: This is a tool you can use to retrieve the implementation of a function using the project's name and function's name as arguments. +2. search_project_files: This is an interactive tool you can use to search the project's source file using bash commands and find definitions or usages of functions, classes, structs, and variables. + The usage guide for the Bash Tool is provided below. diff --git a/stage/writing_stage.py b/stage/writing_stage.py index 7ae28836e..805abc6d6 100644 --- a/stage/writing_stage.py +++ b/stage/writing_stage.py @@ -58,6 +58,7 @@ def execute(self, result_history: list[Result]) -> Result: agent = self.get_agent(index=0) if agent.name == 'FunctionAnalyzer': agent_result = self._execute_agent(agent, result_history) + self.logger.write_chat_history(agent_result) result_history.append(agent_result) # Then, execute the Prototyper agent to refine the fuzz target. diff --git a/tool/fuzz_introspector_tool.py b/tool/fuzz_introspector_tool.py index a56cea09c..5be8e14ec 100644 --- a/tool/fuzz_introspector_tool.py +++ b/tool/fuzz_introspector_tool.py @@ -79,14 +79,11 @@ def function_source_with_signature(self, project_name: str, return function_code - def function_source_with_name(self, project_name: str, - function_name: str) -> str: + def get_function_implementation(self, project_name: str, + function_name: str) -> str: """ Retrieves a function's source from the fuzz introspector API, - using the project's name and function's name. - This function first retrieves the list of all - functions in the project, so it can get the function's signature. - Then it uses the function's signature to retrieve the source code. + using the project's name and function's name Args: project_name (str): The name of the project.
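Reviewer note: the sketch below condenses the ADK wiring that ADKBaseAgent introduces in base_agent.py — build an agents.LlmAgent with a description, an instruction, and plain Python callables as tools; register a session with InMemorySessionService; then drive runners.Runner.run_async() from synchronous code via asyncio.run(). It is a minimal illustration, not part of the patch: the agent name, model string, IDs, prompt text, and the toy tool body are placeholders, and only the google.adk / google.genai calls shown here are the ones that appear in the diff above.

# Minimal, illustrative sketch of the ADK flow added in ADKBaseAgent.
# All names below (model string, IDs, prompt, tool body) are placeholders.
import asyncio

from google.adk import agents, runners, sessions
from google.genai import types


def get_function_implementation(project_name: str, function_name: str) -> str:
  """Toy stand-in for the introspector-backed tool registered by FunctionAnalyzer."""
  return f'// source of {function_name} in {project_name}'


# The LLM agent: a description, an instruction, and plain Python callables as tools.
adk_agent = agents.LlmAgent(
    name='FunctionAnalyzer',
    model='<vertex-ai-model-name>',  # ADKBaseAgent passes llm._vertex_ai_model here.
    description="Extracts a function's requirements from its source implementation.",
    instruction='You are a security engineer analyzing a function...',
    tools=[get_function_implementation],
)

# One in-memory session per (benchmark, trial), keyed the same way as in the patch.
session_service = sessions.InMemorySessionService()
session_service.create_session(app_name='FunctionAnalyzer',
                               user_id='example-benchmark-id',
                               session_id='session_1')

runner = runners.Runner(agent=adk_agent,
                        app_name='FunctionAnalyzer',
                        session_service=session_service)


async def _call(prompt_text: str) -> str:
  """Stream events from the runner and keep the final response text."""
  content = types.Content(role='user', parts=[types.Part(text=prompt_text)])
  final_response_text = ''
  async for event in runner.run_async(user_id='example-benchmark-id',
                                      session_id='session_1',
                                      new_message=content):
    if event.is_final_response() and event.content and event.content.parts:
      final_response_text = event.content.parts[0].text or ''
  return final_response_text


# chat_llm() additionally wraps this coroutine in llm.with_retry_on_error().
print(asyncio.run(_call('Analyze the target function and list its input requirements.')))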