From 12b14f6d402dc00fa83902a086a7cca72e63e796 Mon Sep 17 00:00:00 2001 From: A-F-V Date: Wed, 21 Feb 2024 00:11:20 +0000 Subject: [PATCH] Greatly improved execution --- cli.py | 2 +- src/actions/code_manipulation.py | 58 ++++++++++++++++++++++++++++++++ src/actions/code_search.py | 9 ++--- src/agent.py | 4 ++- src/execution.py | 24 ++++++++++++- src/planning/planner.py | 53 +++++++++++++++++++---------- 6 files changed, 123 insertions(+), 27 deletions(-) create mode 100644 src/actions/code_manipulation.py diff --git a/cli.py b/cli.py index e010faf..30e3d06 100644 --- a/cli.py +++ b/cli.py @@ -41,7 +41,7 @@ def run(repo: str): click.echo("Running the demo") code_path = "edit_distance/edit_distance.py" - query = f"In the file {code_path}, get me line number of the function called `lowest_...` and tell me the return type. Please also give me a bullet point list of code changes to improve the readability of the function." + query = f"In the file {code_path}, get me line number of the function called `lowest_...` and tell me the return type. Create a git diff to remove the docstring." 
context = ProjectContext(folder_path=repo) agent = RefactoringAgent() diff --git a/src/actions/code_manipulation.py b/src/actions/code_manipulation.py new file mode 100644 index 0000000..8d01519 --- /dev/null +++ b/src/actions/code_manipulation.py @@ -0,0 +1,58 @@ +from typing import Optional, List + +from numpy import diff +from src.actions.action import Action +from src.common import Symbol +from src.common.definitions import Definition, pydantic_to_str + +from src.planning.state import RefactoringAgentState +from src.utilities.jedi_utils import ( + goto_symbol, + jedi_name_to_symbol, + span_to_snippet, + symbol_to_definition, +) +from src.utilities.paths import add_path_to_prefix, remove_path_prefix +from ..common import ProjectContext, Symbol + +from langchain.pydantic_v1 import BaseModel, Field +from langchain.tools import tool + +import jedi +import os + + +########################################### +# Tools +example_diff = """ + -Hello, world! + +Hello, universe! + This is an example + -of a unified diff format. + +of the unified diff format.""" + + +class ApplyChangeInput(BaseModel): + file: str = Field(description="The file to apply the change to.") + hunk_header: str = Field(description="The hunk header.") + diff: str = Field(description=f"The diff to apply. 
An example is {example_diff}") + + +def create_apply_change(): + + def apply_change(state: RefactoringAgentState, args: ApplyChangeInput) -> str: + diff_obj = f"""--- {args.file} ++++ {args.file} +{args.hunk_header} +{args.diff} +""" + print(diff_obj) + # Invalidate the code snippets + return diff_obj + + return Action( + id="apply_code_change", + description="Applies a change to a file.", + model_cls=ApplyChangeInput, + f=apply_change, + ) diff --git a/src/actions/code_search.py b/src/actions/code_search.py index d217a81..a6ab2ae 100644 --- a/src/actions/code_search.py +++ b/src/actions/code_search.py @@ -1,6 +1,3 @@ -import json -from re import S -import stat from typing import Optional, List from src.actions.action import Action from src.common import Symbol @@ -38,9 +35,7 @@ class SearchInput(BaseModel): def create_code_search(): - def code_search( - state: RefactoringAgentState, args: SearchInput - ) -> List[Definition]: + def code_search(state: RefactoringAgentState, args: SearchInput) -> str: context = state["project_context"] folder_path = context.folder_path query = args.query @@ -66,7 +61,7 @@ def code_search( for definition in definitions: state["code_blocks"].append(definition.span) - return definitions + return f"Found {len(definitions)} definitions for {query}. Stored under the '' block." 
return Action( id="code_search", diff --git a/src/agent.py b/src/agent.py index b971938..a4071de 100644 --- a/src/agent.py +++ b/src/agent.py @@ -3,6 +3,7 @@ from langchain_openai import ChatOpenAI from langgraph.graph import StateGraph, END from src.actions.code_inspection import create_code_loader +from src.actions.code_manipulation import create_apply_change from src.actions.code_search import create_definition_gotoer from src.actions.code_search import create_code_search from src.planning.planner import LLMExecutor, ShouldContinue, Planner, Thinker @@ -23,6 +24,7 @@ def _create_refactoring_actions(self): action_list = ActionDispatcher() # Code Querying & Manipulation action_list.register_action(create_code_search()) + action_list.register_action(create_apply_change()) # action_list.register_action(create_definition_gotoer()) # action_list.register_action(create_code_loader()) # Git @@ -61,5 +63,5 @@ def run(self, inp: str, context: ProjectContext) -> RefactoringAgentState: "code_blocks": [], "thoughts": [], } - config = RunnableConfig(recursion_limit=50) + config = RunnableConfig(recursion_limit=10) return RefactoringAgentState(**self.app.invoke(state, config=config)) diff --git a/src/execution.py b/src/execution.py index 4a9b56d..f8d39c4 100644 --- a/src/execution.py +++ b/src/execution.py @@ -142,12 +142,14 @@ def __init__( current_task: str, verbose=True, additional_instructions=default_instructions, + record_history=True, ): self.actions = actions self.llm = ChatOpenAI(model="gpt-4-1106-preview") self.current_task = current_task self.verbose = verbose self.additional_instructions = additional_instructions + self.record_history = record_history self.create_prompt() # self.chain = self.prompt_template | self.llm | self.parser @@ -175,7 +177,27 @@ def format_context_prompt(self, state: RefactoringAgentState) -> str: return message_sent def get_openai_tools(self, state): - tools = map(lambda x: x.to_tool(state), self.actions) + actions = self.actions + if 
self.record_history: + + def wrap_with_history(action: Action): + def wrapped_action(state: RefactoringAgentState, args): + result = action.execute(state, args) + request = ActionRequest(id=action.id, args=args) + state["history"].append( + ActionRecord(request=request, result=result) + ) + return result + + return Action( + id=action.id, + description=action.description, + model_cls=action.cls, + f=wrapped_action, + ) + + actions = map(wrap_with_history, actions) + tools = map(lambda x: x.to_tool(state), actions) # open_ai_tools = map(convert_to_openai_function, tools) return list(tools) diff --git a/src/planning/planner.py b/src/planning/planner.py index 48a9a44..4f62228 100644 --- a/src/planning/planner.py +++ b/src/planning/planner.py @@ -40,19 +40,36 @@ def __call__(self, state: RefactoringAgentState): class Thinker: def __init__(self): + + def create_thought(): + class NewThought(BaseModel): + thought: str = Field( + description="The thought to add to the thoughts list" + ) + + def thought(state: RefactoringAgentState, args: NewThought): + state["thoughts"].append(args.thought) + + action = Action( + id="add_thought", + description="Add a thought to the thoughts list.", + model_cls=NewThought, + f=thought, + ) + return action + task = """Reflect on the current state and write a brief thought to help your future self.""" - additional_instructions = """After thinking, you will be prompted to select some actions to execute. You should consider what actions you have already executed before and factor that into your advice. You will be given more opportunities to think and execute in the future, so keep your thoughts extremely brief. - For Example: - `Search for function, read the definition, and extract the function signature` - """ + additional_instructions = """Use this as a way to plan your next steps, reflect on what went well and how you can improve. Be incredibly brief (1-2 sentences). + Call the add_thought function to add a thought to the thoughts list. 
Say 'Done' after you have added your thought.""" self.controller = LLMController( [create_thought()], task, additional_instructions=additional_instructions, record_history=False, ) def __call__(self, state: RefactoringAgentState): - _, thought = self.controller.run(state) - state["thoughts"].append(thought) - return state + return self.controller(state) class NextStep(Enum): @@ -76,25 +93,23 @@ class ShouldContinue: def __init__(self) -> None: should_continue_action = self._create_should_continue_action() task = """Decide whether to think & execute again or finish. """ - additional_instructions = """Decide ONLY which branch to take next: - - Think-Execute branch - - Finish branch - Do not return reply with any information or reasoning. - Execute exactly one function""" + additional_instructions = """ + Call the `should_continue` function with a true boolean to continue thinking & executing, and false to finish. Say 'Done' after you have made your decision.""" self.controller = LLMController( [should_continue_action], task, additional_instructions=additional_instructions, + record_history=False, ) def _create_should_continue_action(self): def should_continue(state: RefactoringAgentState, args: NextStepInput): if args.should_continue: self.next_node = "think" - return "Moving to thinking step" + return "Wait for further instructions." else: self.next_node = "finish" - return "Moving to finish step" + return "Wait for further instructions." 
action = Action( id="should_continue", @@ -113,9 +128,13 @@ def __call__(self, state: RefactoringAgentState): class LLMExecutor: def __init__(self, action_list: List[Action]): task = """Select the next actions to execute.""" - additional_instructions = """You will be allowed to execute actions in the future, so do not worry about executing all the actions at once.""" + additional_instructions = """You will be allowed to execute actions in the future, so do not worry about executing all the actions at once. + Call any of the available functions. Say 'Done' after you are done invoking functions.""" self.executor = LLMController( - action_list, task, additional_instructions=additional_instructions + action_list, + task, + additional_instructions=additional_instructions, + record_history=True, ) def __call__(self, state: RefactoringAgentState):