Skip to content

Commit

Permalink
Consider simplifying
Browse files Browse the repository at this point in the history
  • Loading branch information
A-F-V committed Feb 19, 2024
1 parent a6a6d03 commit 8df0d89
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 172 deletions.
66 changes: 43 additions & 23 deletions src/actions/action.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,87 @@
from abc import ABC, abstractmethod
from ast import Str
from enum import Enum
import json
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from re import I
from sre_constants import SUCCESS
from typing import TypeVar, Generic, Callable, Type, TypedDict
from typing import Optional, TypeVar, Generic, Callable, Type, TypedDict
from unittest.mock import Base
from src.common.definitions import FailureReason, FeedbackMessage
from src.common.definitions import FailureReason

from src.planning.state import RefactoringAgentState
from src.planning.state import ActionRequest, FeedbackMessage, RefactoringAgentState
from ..common import ProjectContext
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser
from langchain.tools import tool, StructuredTool

# Action is an abstract base class

ActionArgs = TypeVar("ActionArgs", bound=BaseModel)

# State = TypeVar("State")

ActionArgs = TypeVar("ActionArgs", bound=BaseModel)
ActionReturnType = TypeVar("ActionReturnType")


class Action(Generic[ActionArgs]):
class Action(Generic[ActionArgs, ActionReturnType]):
def __init__(
self,
id,
description,
model_cls: Type[ActionArgs],
f: Callable[[RefactoringAgentState, ActionArgs], str],
f: Callable[[RefactoringAgentState, ActionArgs], ActionReturnType],
return_direct=False,
):
self.id = id
self.description = description
self.parser = JsonOutputParser(pydantic_object=model_cls)
self.f = f
self.cls = model_cls
self.return_direct = return_direct

def execute(self, state: RefactoringAgentState, action_str: str) -> str:
# TODO: error handling
action_args_kwargs = self.parser.invoke(action_str)
args = self.cls(**action_args_kwargs)
result = self.f(state, args)
return result
def execute(
self, state: RefactoringAgentState, args: ActionArgs
) -> ActionReturnType:
return self.f(state, args)

def to_tool(self, state: RefactoringAgentState) -> StructuredTool:
def tool_f(**kwargs):
class Callbacks(BaseCallbackHandler):
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs,
):
if error is FeedbackMessage:
state["feedback"].append(error)

def tool_f(**kwargs) -> ActionReturnType:
args_str = json.dumps(kwargs)
try:
try:
args = self.cls(**kwargs)
except Exception as e:
raise FeedbackMessage(FailureReason.INVALID_ACTION_ARGS, str(e))
args = self.cls(**kwargs)
request = ActionRequest(id=self.id, args=args)
try:
return self.f(state, args)
except FeedbackMessage as f:
raise f
except Exception as e:
raise FeedbackMessage(FailureReason.ACTION_FAILED, str(e))
except FeedbackMessage as f:
state["feedback"].append(f)
raise FeedbackMessage(
FailureReason.ACTION_FAILED, str(e), request=request
)
except Exception as e:
raise FeedbackMessage(
FailureReason.INVALID_ACTION_ARGS,
f"Coulnd't parse arguments: {args_str}",
)

return StructuredTool(
name=self.id,
callbacks=[Callbacks()],
description=self.description,
args_schema=self.cls,
func=tool_f,
return_direct=self.return_direct,
)

def __str__(self):
Expand Down
8 changes: 4 additions & 4 deletions src/actions/code_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@


class ReadCodeSnippetInput(BaseModel):
code_span: dict = Field(description="The code span to read")
code_span: CodeSpan = Field(description="The code span to read")


# TODO: Error handling
def create_code_loader():
def code_reader(state: RefactoringAgentState, args: ReadCodeSnippetInput) -> str:
span = CodeSpan(**args.code_span)
span = args.code_span
snippet = span_to_snippet(span, state["project_context"])

state["code_snippets"].append(snippet)
return f"Loaded code snippet from {span.file_location} at lines {span.start_line} to {span.end_line}"
return f"Loaded code snippet from {span.file_path} at lines {span.start_line} to {span.end_line}"

return Action(
id="code_load",
description="Loads a code span from the project",
description="Loads a block of code as a snippet. The block loaded must be a CodeSpan type object that has been returned from a code search. The snippet will be stored in the state for later use.",
model_cls=ReadCodeSnippetInput,
f=code_reader,
)
19 changes: 10 additions & 9 deletions src/actions/code_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from re import S
import stat
from typing import Optional, List
from src.actions.action import Action
Expand Down Expand Up @@ -32,7 +33,7 @@ class SearchInput(BaseModel):

def create_code_search():

def code_search(state: RefactoringAgentState, args: SearchInput) -> str:
def code_search(state: RefactoringAgentState, args: SearchInput) -> Symbol:
folder_path = state["project_context"].folder_path
query = args.query
fuzzy = args.fuzzy
Expand All @@ -52,31 +53,31 @@ def code_search(state: RefactoringAgentState, args: SearchInput) -> str:
jedi_name_to_symbol(completion, state["project_context"])
for completion in completions
]
output = "\n".join(map(pydantic_to_str, output))
return output
# output = "\n".join(map(lambda x: pydantic_to_str(x, "symbol"), output))
return output[0]

return Action(
id="code_search",
description="Performs a search for a symbol in a file or folder.",
description="Performs a search for a symbol in a file or folder. This will not likely return the definition, so you will have to ask for that explicitly",
model_cls=SearchInput,
f=code_search,
)


class GotoDefinitionInput(BaseModel):
symbol: dict = Field(description="The symbol to get the definition of.")
symbol: Symbol = Field(description="The symbol to get the definition of.")


# TODO: Error handling
def create_definition_gotoer():
def code_goto_definition(
state: RefactoringAgentState, args: GotoDefinitionInput
) -> str:
symbol = Symbol(**args.symbol)
) -> Symbol:
symbol = args.symbol

source_name = goto_symbol(state["project_context"], symbol)

assert len(source_name._name) == 1
assert len(source_name) == 1

definition_name = source_name[0].goto()

Expand All @@ -86,7 +87,7 @@ def code_goto_definition(
definition_name[0], state["project_context"]
)

return pydantic_to_str(definition_symbol)
return definition_symbol

return Action(
id="code_goto_definition",
Expand Down
10 changes: 5 additions & 5 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from src.actions.code_inspection import create_code_loader
from src.actions.code_search import create_definition_gotoer
from src.actions.code_search import create_code_search
from src.planning.planner import DecisionMaker, Planner
from src.planning.planner import ShouldContinue, Planner
from src.planning.state import RefactoringAgentState
from .common.definitions import ProjectContext
from .execution import ActionDispatcher, ExecuteTopOfPlan, LLMController
from .execution import ActionDispatcher, ExecutePlan, ExecuteTopOfPlan, LLMController
from .actions.basic_actions import create_logging_action


Expand Down Expand Up @@ -41,16 +41,16 @@ def _setup_agent_graph(self):
self.graph = StateGraph(RefactoringAgentState)

self.graph.add_node("planner", Planner(action_list))
self.graph.add_node("execute", ExecuteTopOfPlan(action_list))
self.graph.add_node("execute", ExecutePlan(action_list))
self.graph.add_node(
"finish",
LLMController(
[create_logging_action()],
"Log any results you wish to show the user by calling print_message.",
),
)
self.graph.add_conditional_edges("planner", DecisionMaker())
self.graph.add_conditional_edges("execute", DecisionMaker())
self.graph.add_edge("planner", "execute")
self.graph.add_conditional_edges("execute", ShouldContinue())
self.graph.add_edge("finish", END)
self.graph.set_entry_point("planner")
# self.graph.add_node('')
Expand Down
43 changes: 7 additions & 36 deletions src/common/definitions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from enum import Enum
from pydantic import BaseModel, Field
from typing import TypedDict
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Optional, TypedDict


def pydantic_to_str(request: BaseModel, with_name: bool = True) -> str:
def pydantic_to_str(request: BaseModel, name: str) -> str:
# get name of type
name = request.__class__.__name__ if with_name else ""
return f"{name}{request.model_dump()}"
return f"{{'{name}':{request.dict()}}}"


class ProjectContext(BaseModel):
Expand All @@ -19,13 +18,13 @@ class ProjectContext(BaseModel):
# Action Defs
class Symbol(BaseModel):
name: str = Field(description="The name of the symbol")
file_location: str = Field(description="The file location of the symbol")
file_path: str = Field(description="The file path of the symbol")
line: int = Field(description="The line number of the symbol")
column: int = Field(description="The column number of the symbol")


class CodeSpan(BaseModel):
file_location: str = Field(description="The file location of the span")
file_path: str = Field(description="The file location of the span")
start_line: int = Field(description="The starting line number of the span")
end_line: int = Field(description="The ending line number of the span")

Expand All @@ -45,25 +44,8 @@ class ActionSuccess(Enum):
ACTION_FAILED = "ACTION_FAILED"


class ActionRequest(TypedDict):
id: str
action_str: str


def request_to_str(request: ActionRequest) -> str:
return f"{{\"name\":{request['id']},\"parameters\":{request['action_str']}}}"


class ActionRecord(TypedDict):
request: ActionRequest
result: str


# can we just store the parsed action str?
# Don't need Success?
def record_to_str(record: ActionRecord) -> str:
return f"{{\"request\":{request_to_str(record['request'])},\"result\":'{record['result']}'}}"


class FailureReason(Enum):
ACTION_NOT_FOUND = "ACTION_NOT_FOUND"
INVALID_ACTION_ARGS = "INVALID_ACTION_ARGS"
Expand All @@ -72,17 +54,6 @@ class FailureReason(Enum):


# Make FeedbackMessage an Exception
class FeedbackMessage(Exception):
def __init__(self, failure_reason: FailureReason, message: str):
self.reason = failure_reason
self.message = message
super().__init__(message)


def feedback_to_str(feedback: FeedbackMessage) -> str:
return f'{{"failure-reason":{feedback.reason.value},"message":{feedback.message}}}'


class CodeSnippet(TypedDict):
file_path: str
code: str
Expand Down
Loading

0 comments on commit 8df0d89

Please sign in to comment.