Skip to content

Commit

Permalink
Ready to do eval
Browse files Browse the repository at this point in the history
  • Loading branch information
A-F-V committed Feb 21, 2024
1 parent 12b14f6 commit 320c3ba
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(repo: str):
click.echo("Running the demo")

code_path = "edit_distance/edit_distance.py"
query = f"In the file {code_path}, get me line number of the function called `lowest_...` and tell me the return type. Create a git diff to remove the docstring."
query = f"In the file {code_path}, there is function called `lowest_...`. Edit the function by using better names for the variables. Do not rename the function"

context = ProjectContext(folder_path=repo)
agent = RefactoringAgent()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ python-dotenv

# Git Manipulation
pygit2

diff_match_patch
# LSP
jedi

5 changes: 3 additions & 2 deletions src/actions/action.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from re import I
from sre_constants import SUCCESS
from typing import Optional, TypeVar, Generic, Callable, Type, TypedDict
from unittest.mock import Base
from src.common.definitions import FailureReason
Expand Down Expand Up @@ -86,3 +84,6 @@ def tool_f(**kwargs) -> ActionReturnType:

def __str__(self):
return f"""{{"name": '{self.id}', "description": '{self.description}', "parameters": {self.cls.schema()['properties']}}}"""


# TODO: Fast Exit Action decorates the action with a 'Say Done' Message
105 changes: 81 additions & 24 deletions src/actions/code_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Optional, List

from numpy import diff
from regex import E
from src.actions.action import Action
from src.common import Symbol
from src.common.definitions import Definition, pydantic_to_str
from src.common.definitions import (
CodeChange,
CodeSnippet,
CodeSpan,
Definition,
pydantic_to_str,
)

from src.planning.state import RefactoringAgentState
from src.utilities.jedi_utils import (
Expand All @@ -12,47 +19,97 @@
span_to_snippet,
symbol_to_definition,
)
from src.utilities.paths import add_path_to_prefix, remove_path_prefix
from src.utilities.paths import (
add_path_to_prefix,
remove_path_prefix,
standardize_code_path,
)
from ..common import ProjectContext, Symbol

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import tool

import jedi
import os

import pygit2 as git

###########################################
# Tools
example_diff = """
-Hello, world!
+Hello, universe!
This is an example
-of a unified diff format.
+of the unified diff format."""


class ApplyChangeInput(BaseModel):
def apply_change_to_file(context: ProjectContext, change: CodeChange):
# Load the file
path = add_path_to_prefix(context.folder_path, change.file)
with open(path, "r") as file:
code = file.readlines()

# Apply the change
start = change.start_line - 1
end = change.end_line - 1 # File index to 0 index
replacement_code = change.replacement_code.split("\n")
replacement_code = [f"{line}\n" for line in replacement_code]
code[start:end] = replacement_code

# Write the change
with open(path, "w") as file:
file.write("".join(code))

return "Change applied"


def apply_change_to_snippets(change: CodeChange, span: CodeSpan):
# If snippet is in a different file, return
diff_len = len(change.replacement_code.split("\n"))
if change.replacement_code == "":
diff_len = 0
change_in_lines = diff_len - (change.end_line - change.start_line)
s_path = standardize_code_path(span.file_path)
c_path = standardize_code_path(change.file)
if s_path != c_path:
return

# if change is below span, nothing to do
if span.end_line < change.start_line:
return

# if the change is wholly within the span, update the span
if span.start_line <= change.start_line and span.end_line > change.end_line:
# Move the end line by the difference in lines
span.end_line += change_in_lines
assert span.end_line >= span.start_line
return


class CodeChangeInput(BaseModel):
file: str = Field(description="The file to apply the change to.")
hunk_header: str = Field(description="The hunk header.")
diff: str = Field(description=f"The diff to apply. An example is {example_diff}")
start_line: int = Field(description="The start line of the change.")
end_line: int = Field(description="The end line of the change exclusive.")
replacement_code: str = Field(description="The replacement code to apply.")
change_summary: str = Field(description="A summary of the change")


def create_apply_change():

def apply_change(state: RefactoringAgentState, args: ApplyChangeInput) -> str:
diff_obj = f"""--- {args.file}
+++ {args.file}
{args.hunk_header}
{args.diff}
"""
print(diff_obj)
# Invalidate the code snippets
return diff_obj
def apply_change(state: RefactoringAgentState, args: CodeChangeInput) -> str:
change = CodeChange(
file=args.file,
start_line=args.start_line,
end_line=args.end_line,
replacement_code=args.replacement_code,
)

print(pydantic_to_str(args, "ApplyChangeInput"))

# Apply the change to the file
apply_change_to_file(state["project_context"], change)
# Apply the change to the code snippets
for span in state["code_blocks"]:
apply_change_to_snippets(change, span)
return f"Changed applied to {args.file}: {args.change_summary}"

return Action(
id="apply_code_change",
description="Applies a change to a file.",
model_cls=ApplyChangeInput,
id="edit_code",
description="Applies a change to a file. Achieves this by replacing the old code with your replacement code.",
model_cls=CodeChangeInput,
f=apply_change,
)
2 changes: 1 addition & 1 deletion src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def run(self, inp: str, context: ProjectContext) -> RefactoringAgentState:
"code_blocks": [],
"thoughts": [],
}
config = RunnableConfig(recursion_limit=10)
config = RunnableConfig(recursion_limit=20)
return RefactoringAgentState(**self.app.invoke(state, config=config))
7 changes: 7 additions & 0 deletions src/common/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ class Definition(BaseModel):
span: CodeSpan = Field(description="The span of the definition")


class CodeChange(BaseModel):
file: str = Field(description="The file to apply the change to.")
start_line: int = Field(description="The start line of the change.")
end_line: int = Field(description="The end line of the change exclusive.")
replacement_code: str = Field(description="The replacement code to apply.")


##########################################
# State Defs

Expand Down
10 changes: 5 additions & 5 deletions src/planning/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class NewThought(BaseModel):

def thought(state: RefactoringAgentState, args: NewThought):
state["thoughts"].append(args.thought)
return 'Say "Done"'

action = Action(
id="add_thought",
Expand Down Expand Up @@ -93,8 +94,7 @@ class ShouldContinue:
def __init__(self) -> None:
should_continue_action = self._create_should_continue_action()
task = """Decide whether to think & execute again or finish. """
additional_instructions = """
Call the `should_continue` function with a true boolean to continue thinking & executing, and false to finish. Say 'Done' after you have added your thought.."""
additional_instructions = """Call the `should_continue` function with a true boolean to continue thinking & executing, and false to finish. Say 'Done' after you have called `should_continue`. Call `should_continue` only once."""
self.controller = LLMController(
[should_continue_action],
task,
Expand All @@ -104,16 +104,16 @@ def __init__(self) -> None:

def _create_should_continue_action(self):
def should_continue(state: RefactoringAgentState, args: NextStepInput):
message = f"You said should_continue={args.should_continue}. Wait for further instructions and do not invoke any functions including `should_continue`. Simply say 'Done'"
if args.should_continue:
self.next_node = "think"
return "Wait for further instructions."
else:
self.next_node = "finish"
return "Wait for further instructions."
return message

action = Action(
id="should_continue",
description="""true = think and execute, false = finish""",
description="""true = think and execute, false = finish.""",
model_cls=NextStepInput,
# return_direct=True,
f=should_continue,
Expand Down
6 changes: 5 additions & 1 deletion src/planning/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def record_to_str(record: ActionRecord) -> str:
# Get the name of the type of record.result
type_name = record["result"].__class__.__name__
# check if request.result is a derived class of BaseModel
result_str = f"{type_name}({record['result']})"
# if type_name is str or int, then result_str is the value of record.result
if type_name == "str" or type_name == "int":
result_str = record["result"]
else:
result_str = f"{type_name}({record['result']})"

return f"{{\"request\":{request_to_str(record['request'])},\"result\":\"{result_str}\"}}"

Expand Down
6 changes: 3 additions & 3 deletions src/utilities/jedi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def jedi_name_to_symbol(name, context: ProjectContext) -> Symbol:

def load_code(span: CodeSpan, context: ProjectContext):
path = add_path_to_prefix(context.folder_path, span.file_path)
start = span.start_line
end = span.end_line
start = span.start_line - 1
end = span.end_line - 1
with open(path, "r") as file:
code = file.readlines()
# Get the code
return "".join(code[start - 1 : end + 1])
return "".join(code[start : end + 1])


def add_line_numbers(code: str, starting_line: int) -> str:
Expand Down
14 changes: 11 additions & 3 deletions src/utilities/paths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import os

# TODO
# class RelativeFilePath:


def standardize_code_path(path: str) -> str:
path = path.replace("\\", "/")
if path.startswith("/"):
return path[1:]
return path


def remove_path_prefix(path: str, prefix: str) -> str:
# turn both into absolute paths
Expand All @@ -12,6 +22,4 @@ def remove_path_prefix(path: str, prefix: str) -> str:


def add_path_to_prefix(prefix: str, path: str):
if path.startswith("/") or path.startswith("\\"):
return os.path.join(prefix, path[1:])
return os.path.join(prefix, path)
return os.path.join(prefix, standardize_code_path(path))

0 comments on commit 320c3ba

Please sign in to comment.