Skip to content

Commit

Permalink
Greatly improved execution
Browse files Browse the repository at this point in the history
  • Loading branch information
A-F-V committed Feb 21, 2024
1 parent a34bcd3 commit 12b14f6
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 27 deletions.
2 changes: 1 addition & 1 deletion cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(repo: str):
click.echo("Running the demo")

code_path = "edit_distance/edit_distance.py"
query = f"In the file {code_path}, get me line number of the function called `lowest_...` and tell me the return type. Please also give me a bullet point list of code changes to improve the readability of the function."
query = f"In the file {code_path}, get me line number of the function called `lowest_...` and tell me the return type. Create a git diff to remove the docstring."

context = ProjectContext(folder_path=repo)
agent = RefactoringAgent()
Expand Down
58 changes: 58 additions & 0 deletions src/actions/code_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Optional, List

from numpy import diff
from src.actions.action import Action
from src.common import Symbol
from src.common.definitions import Definition, pydantic_to_str

from src.planning.state import RefactoringAgentState
from src.utilities.jedi_utils import (
goto_symbol,
jedi_name_to_symbol,
span_to_snippet,
symbol_to_definition,
)
from src.utilities.paths import add_path_to_prefix, remove_path_prefix
from ..common import ProjectContext, Symbol

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import tool

import jedi
import os


###########################################
# Tools
example_diff = """
-Hello, world!
+Hello, universe!
This is an example
-of a unified diff format.
+of the unified diff format."""


class ApplyChangeInput(BaseModel):
file: str = Field(description="The file to apply the change to.")
hunk_header: str = Field(description="The hunk header.")
diff: str = Field(description=f"The diff to apply. An example is {example_diff}")


def create_apply_change():

def apply_change(state: RefactoringAgentState, args: ApplyChangeInput) -> str:
diff_obj = f"""--- {args.file}
+++ {args.file}
{args.hunk_header}
{args.diff}
"""
print(diff_obj)
# Invalidate the code snippets
return diff_obj

return Action(
id="apply_code_change",
description="Applies a change to a file.",
model_cls=ApplyChangeInput,
f=apply_change,
)
9 changes: 2 additions & 7 deletions src/actions/code_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import json
from re import S
import stat
from typing import Optional, List
from src.actions.action import Action
from src.common import Symbol
Expand Down Expand Up @@ -38,9 +35,7 @@ class SearchInput(BaseModel):

def create_code_search():

def code_search(
state: RefactoringAgentState, args: SearchInput
) -> List[Definition]:
def code_search(state: RefactoringAgentState, args: SearchInput) -> str:
context = state["project_context"]
folder_path = context.folder_path
query = args.query
Expand All @@ -66,7 +61,7 @@ def code_search(
for definition in definitions:
state["code_blocks"].append(definition.span)

return definitions
return f"Found {len(definitions)} definitions for {query}. Stored under the '<Code Snippets>' block."

return Action(
id="code_search",
Expand Down
4 changes: 3 additions & 1 deletion src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from src.actions.code_inspection import create_code_loader
from src.actions.code_manipulation import create_apply_change
from src.actions.code_search import create_definition_gotoer
from src.actions.code_search import create_code_search
from src.planning.planner import LLMExecutor, ShouldContinue, Planner, Thinker
Expand All @@ -23,6 +24,7 @@ def _create_refactoring_actions(self):
action_list = ActionDispatcher()
# Code Querying & Manipulation
action_list.register_action(create_code_search())
action_list.register_action(create_apply_change())
# action_list.register_action(create_definition_gotoer())
# action_list.register_action(create_code_loader())
# Git
Expand Down Expand Up @@ -61,5 +63,5 @@ def run(self, inp: str, context: ProjectContext) -> RefactoringAgentState:
"code_blocks": [],
"thoughts": [],
}
config = RunnableConfig(recursion_limit=50)
config = RunnableConfig(recursion_limit=10)
return RefactoringAgentState(**self.app.invoke(state, config=config))
24 changes: 23 additions & 1 deletion src/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ def __init__(
current_task: str,
verbose=True,
additional_instructions=default_instructions,
record_history=True,
):
self.actions = actions
self.llm = ChatOpenAI(model="gpt-4-1106-preview")
self.current_task = current_task
self.verbose = verbose
self.additional_instructions = additional_instructions
self.record_history = record_history
self.create_prompt()
# self.chain = self.prompt_template | self.llm | self.parser

Expand Down Expand Up @@ -175,7 +177,27 @@ def format_context_prompt(self, state: RefactoringAgentState) -> str:
return message_sent

def get_openai_tools(self, state):
tools = map(lambda x: x.to_tool(state), self.actions)
actions = self.actions
if self.record_history:

def wrap_with_history(action: Action):
def wrapped_action(state: RefactoringAgentState, args):
result = action.execute(state, args)
request = ActionRequest(id=action.id, args=args)
state["history"].append(
ActionRecord(request=request, result=result)
)
return result

return Action(
id=action.id,
description=action.description,
model_cls=action.cls,
f=wrapped_action,
)

actions = map(wrap_with_history, actions)
tools = map(lambda x: x.to_tool(state), actions)
# open_ai_tools = map(convert_to_openai_function, tools)
return list(tools)

Expand Down
53 changes: 36 additions & 17 deletions src/planning/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,36 @@ def __call__(self, state: RefactoringAgentState):

class Thinker:
def __init__(self):

def create_thought():
class NewThought(BaseModel):
thought: str = Field(
description="The thought to add to the thoughts list"
)

def thought(state: RefactoringAgentState, args: NewThought):
state["thoughts"].append(args.thought)

action = Action(
id="add_thought",
description="Add a thought to the thoughts list.",
model_cls=NewThought,
f=thought,
)
return action

task = """Reflect on the current state and write a brief thought to help your future self."""
additional_instructions = """After thinking, you will be prompted to select some actions to execute. You should consider what actions you have already executed before and factor that into your advice. You will be given more opportunities to think and execute in the future, so keep your thoughts extremely brief.
For Example:
`Search for function, read the definition, and extract the function signature`
"""
additional_instructions = """Use this as a way to plan your next steps, reflect on what went well and how you can improve. Be incredibly brief (1-2 sentences).
Call the add_thought function to add a thought to the thoughts list. Say 'Done' after you have added your thought."""
self.controller = LLMController(
[], task, additional_instructions=additional_instructions
[create_thought()],
task,
additional_instructions=additional_instructions,
record_history=False,
)

def __call__(self, state: RefactoringAgentState):
_, thought = self.controller.run(state)
state["thoughts"].append(thought)
return state
return self.controller(state)


class NextStep(Enum):
Expand All @@ -76,25 +93,23 @@ class ShouldContinue:
def __init__(self) -> None:
should_continue_action = self._create_should_continue_action()
task = """Decide whether to think & execute again or finish. """
additional_instructions = """Decide ONLY which branch to take next:
- Think-Execute branch
- Finish branch
Do not return reply with any information or reasoning.
Execute exactly one function"""
additional_instructions = """
Call the `should_continue` function with a true boolean to continue thinking & executing, and false to finish. Say 'Done' after you have added your thought.."""
self.controller = LLMController(
[should_continue_action],
task,
additional_instructions=additional_instructions,
record_history=False,
)

def _create_should_continue_action(self):
def should_continue(state: RefactoringAgentState, args: NextStepInput):
if args.should_continue:
self.next_node = "think"
return "Moving to thinking step"
return "Wait for further instructions."
else:
self.next_node = "finish"
return "Moving to finish step"
return "Wait for further instructions."

action = Action(
id="should_continue",
Expand All @@ -113,9 +128,13 @@ def __call__(self, state: RefactoringAgentState):
class LLMExecutor:
def __init__(self, action_list: List[Action]):
task = """Select the next actions to execute."""
additional_instructions = """You will be allowed to execute actions in the future, so do not worry about executing all the actions at once."""
additional_instructions = """You will be allowed to execute actions in the future, so do not worry about executing all the actions at once.
Call any of the available functions. Say 'Done' after you are done invoking functions."""
self.executor = LLMController(
action_list, task, additional_instructions=additional_instructions
action_list,
task,
additional_instructions=additional_instructions,
record_history=True,
)

def __call__(self, state: RefactoringAgentState):
Expand Down

0 comments on commit 12b14f6

Please sign in to comment.