From 8df0d8904b82131459273302329828865f257de4 Mon Sep 17 00:00:00 2001 From: A-F-V Date: Mon, 19 Feb 2024 23:43:32 +0000 Subject: [PATCH] Consider simplifying --- src/actions/action.py | 66 ++++++++++++++++++---------- src/actions/code_inspection.py | 8 ++-- src/actions/code_search.py | 19 ++++---- src/agent.py | 10 ++--- src/common/definitions.py | 43 +++--------------- src/execution.py | 80 ++++++++++++++-------------------- src/planning/plan_actions.py | 15 +++++-- src/planning/planner.py | 53 +++++++++++----------- src/planning/state.py | 72 ++++++++++++++++++++++++------ src/utilities/jedi_utils.py | 10 ++--- 10 files changed, 204 insertions(+), 172 deletions(-) diff --git a/src/actions/action.py b/src/actions/action.py index 77db7e4..c0b591d 100644 --- a/src/actions/action.py +++ b/src/actions/action.py @@ -1,14 +1,13 @@ -from abc import ABC, abstractmethod -from ast import Str -from enum import Enum import json +from uuid import UUID +from langchain.callbacks.base import BaseCallbackHandler from re import I from sre_constants import SUCCESS -from typing import TypeVar, Generic, Callable, Type, TypedDict +from typing import Optional, TypeVar, Generic, Callable, Type, TypedDict from unittest.mock import Base -from src.common.definitions import FailureReason, FeedbackMessage +from src.common.definitions import FailureReason -from src.planning.state import RefactoringAgentState +from src.planning.state import ActionRequest, FeedbackMessage, RefactoringAgentState from ..common import ProjectContext from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.output_parsers import JsonOutputParser @@ -16,52 +15,73 @@ # Action is an abstract base class -ActionArgs = TypeVar("ActionArgs", bound=BaseModel) + # State = TypeVar("State") +ActionArgs = TypeVar("ActionArgs", bound=BaseModel) +ActionReturnType = TypeVar("ActionReturnType") + -class Action(Generic[ActionArgs]): +class Action(Generic[ActionArgs, ActionReturnType]): def __init__( self, id, description, model_cls: Type[ActionArgs], - f: Callable[[RefactoringAgentState, ActionArgs], str], + f: Callable[[RefactoringAgentState, ActionArgs], ActionReturnType], + return_direct=False, ): self.id = id self.description = description self.parser = JsonOutputParser(pydantic_object=model_cls) self.f = f self.cls = model_cls + self.return_direct = return_direct - def execute(self, state: RefactoringAgentState, action_str: str) -> str: - # TODO: error handling - action_args_kwargs = self.parser.invoke(action_str) - args = self.cls(**action_args_kwargs) - result = self.f(state, args) - return result + def execute( + self, state: RefactoringAgentState, args: ActionArgs + ) -> ActionReturnType: + return self.f(state, args) def to_tool(self, state: RefactoringAgentState) -> StructuredTool: - def tool_f(**kwargs): + class Callbacks(BaseCallbackHandler): + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs, + ): + if error is FeedbackMessage: + state["feedback"].append(error) + + def tool_f(**kwargs) -> ActionReturnType: + args_str = json.dumps(kwargs) try: - try: - args = self.cls(**kwargs) - except Exception as e: - raise FeedbackMessage(FailureReason.INVALID_ACTION_ARGS, str(e)) + args = self.cls(**kwargs) + request = ActionRequest(id=self.id, args=args) try: return self.f(state, args) except FeedbackMessage as f: raise f except Exception as e: - raise FeedbackMessage(FailureReason.ACTION_FAILED, str(e)) - except FeedbackMessage as f: - state["feedback"].append(f) + raise FeedbackMessage( + FailureReason.ACTION_FAILED, str(e), request=request + ) + except Exception as e: + raise FeedbackMessage( + FailureReason.INVALID_ACTION_ARGS, + f"Coulnd't parse arguments: {args_str}", + ) return StructuredTool( name=self.id, + callbacks=[Callbacks()], description=self.description, args_schema=self.cls, func=tool_f, + return_direct=self.return_direct, ) def __str__(self): diff --git a/src/actions/code_inspection.py b/src/actions/code_inspection.py index dc0f05a..00b3647 100644 --- a/src/actions/code_inspection.py +++ b/src/actions/code_inspection.py @@ -17,21 +17,21 @@ class ReadCodeSnippetInput(BaseModel): - code_span: dict = Field(description="The code span to read") + code_span: CodeSpan = Field(description="The code span to read") # TODO: Error handling def create_code_loader(): def code_reader(state: RefactoringAgentState, args: ReadCodeSnippetInput) -> str: - span = CodeSpan(**args.code_span) + span = args.code_span snippet = span_to_snippet(span, state["project_context"]) state["code_snippets"].append(snippet) - return f"Loaded code snippet from {span.file_location} at lines {span.start_line} to {span.end_line}" + return f"Loaded code snippet from {span.file_path} at lines {span.start_line} to {span.end_line}" return Action( id="code_load", - description="Loads a code span from the project", + description="Loads a block of code as a snippet. The block loaded must be a CodeSpan type object that has been returned from a code search. The snippet will be stored in the state for later use.", model_cls=ReadCodeSnippetInput, f=code_reader, ) diff --git a/src/actions/code_search.py b/src/actions/code_search.py index fe11bfd..33c70f7 100644 --- a/src/actions/code_search.py +++ b/src/actions/code_search.py @@ -1,4 +1,5 @@ import json +from re import S import stat from typing import Optional, List from src.actions.action import Action @@ -32,7 +33,7 @@ class SearchInput(BaseModel): def create_code_search(): - def code_search(state: RefactoringAgentState, args: SearchInput) -> str: + def code_search(state: RefactoringAgentState, args: SearchInput) -> Symbol: folder_path = state["project_context"].folder_path query = args.query fuzzy = args.fuzzy @@ -52,31 +53,31 @@ def code_search(state: RefactoringAgentState, args: SearchInput) -> str: jedi_name_to_symbol(completion, state["project_context"]) for completion in completions ] - output = "\n".join(map(pydantic_to_str, output)) - return output + # output = "\n".join(map(lambda x: pydantic_to_str(x, "symbol"), output)) + return output[0] return Action( id="code_search", - description="Performs a search for a symbol in a file or folder.", + description="Performs a search for a symbol in a file or folder. This will not likely return the definition, so you will have to ask for that explicitly", model_cls=SearchInput, f=code_search, ) class GotoDefinitionInput(BaseModel): - symbol: dict = Field(description="The symbol to get the definition of.") + symbol: Symbol = Field(description="The symbol to get the definition of.") # TODO: Error handling def create_definition_gotoer(): def code_goto_definition( state: RefactoringAgentState, args: GotoDefinitionInput - ) -> str: - symbol = Symbol(**args.symbol) + ) -> Symbol: + symbol = args.symbol source_name = goto_symbol(state["project_context"], symbol) - assert len(source_name._name) == 1 + assert len(source_name) == 1 definition_name = source_name[0].goto() @@ -86,7 +87,7 @@ def code_goto_definition( definition_name[0], state["project_context"] ) - return pydantic_to_str(definition_symbol) + return definition_symbol return Action( id="code_goto_definition", diff --git a/src/agent.py b/src/agent.py index 25310a4..6eae9ff 100644 --- a/src/agent.py +++ b/src/agent.py @@ -5,10 +5,10 @@ from src.actions.code_inspection import create_code_loader from src.actions.code_search import create_definition_gotoer from src.actions.code_search import create_code_search -from src.planning.planner import DecisionMaker, Planner +from src.planning.planner import ShouldContinue, Planner from src.planning.state import RefactoringAgentState from .common.definitions import ProjectContext -from .execution import ActionDispatcher, ExecuteTopOfPlan, LLMController +from .execution import ActionDispatcher, ExecutePlan, ExecuteTopOfPlan, LLMController from .actions.basic_actions import create_logging_action @@ -41,7 +41,7 @@ def _setup_agent_graph(self): self.graph = StateGraph(RefactoringAgentState) self.graph.add_node("planner", Planner(action_list)) - self.graph.add_node("execute", ExecuteTopOfPlan(action_list)) + self.graph.add_node("execute", ExecutePlan(action_list)) self.graph.add_node( "finish", LLMController( @@ -49,8 +49,8 @@ def _setup_agent_graph(self): "Log any results you wish to show the user by calling print_message.", ), ) - self.graph.add_conditional_edges("planner", DecisionMaker()) - self.graph.add_conditional_edges("execute", DecisionMaker()) + self.graph.add_edge("planner", "execute") + self.graph.add_conditional_edges("execute", ShouldContinue()) self.graph.add_edge("finish", END) self.graph.set_entry_point("planner") # self.graph.add_node('') diff --git a/src/common/definitions.py b/src/common/definitions.py index 8a51e21..67311b4 100644 --- a/src/common/definitions.py +++ b/src/common/definitions.py @@ -1,12 +1,11 @@ from enum import Enum -from pydantic import BaseModel, Field -from typing import TypedDict +from langchain_core.pydantic_v1 import BaseModel, Field +from typing import Optional, TypedDict -def pydantic_to_str(request: BaseModel, with_name: bool = True) -> str: +def pydantic_to_str(request: BaseModel, name: str) -> str: # get name of type - name = request.__class__.__name__ if with_name else "" - return f"{name}{request.model_dump()}" + return f"{{'{name}':{request.dict()}}}" class ProjectContext(BaseModel): @@ -19,13 +18,13 @@ class ProjectContext(BaseModel): # Action Defs class Symbol(BaseModel): name: str = Field(description="The name of the symbol") - file_location: str = Field(description="The file location of the symbol") + file_path: str = Field(description="The file path of the symbol") line: int = Field(description="The line number of the symbol") column: int = Field(description="The column number of the symbol") class CodeSpan(BaseModel): - file_location: str = Field(description="The file location of the span") + file_path: str = Field(description="The file location of the span") start_line: int = Field(description="The starting line number of the span") end_line: int = Field(description="The ending line number of the span") @@ -45,25 +44,8 @@ class ActionSuccess(Enum): ACTION_FAILED = "ACTION_FAILED" -class ActionRequest(TypedDict): - id: str - action_str: str - - -def request_to_str(request: ActionRequest) -> str: - return f"{{\"name\":{request['id']},\"parameters\":{request['action_str']}}}" - - -class ActionRecord(TypedDict): - request: ActionRequest - result: str - - +# can we just store the parsed action str? # Don't need Success? -def record_to_str(record: ActionRecord) -> str: - return f"{{\"request\":{request_to_str(record['request'])},\"result\":'{record['result']}'}}" - - class FailureReason(Enum): ACTION_NOT_FOUND = "ACTION_NOT_FOUND" INVALID_ACTION_ARGS = "INVALID_ACTION_ARGS" @@ -72,17 +54,6 @@ class FailureReason(Enum): # Make FeedbackMessage an Exception -class FeedbackMessage(Exception): - def __init__(self, failure_reason: FailureReason, message: str): - self.reason = failure_reason - self.message = message - super().__init__(message) - - -def feedback_to_str(feedback: FeedbackMessage) -> str: - return f'{{"failure-reason":{feedback.reason.value},"message":{feedback.message}}}' - - class CodeSnippet(TypedDict): file_path: str code: str diff --git a/src/execution.py b/src/execution.py index 698885a..1a19a7c 100644 --- a/src/execution.py +++ b/src/execution.py @@ -1,11 +1,12 @@ import json from nis import cat +import re import sys from tabnanny import verbose from typing import Dict, List from langchain_openai import ChatOpenAI from langchain_core.utils.function_calling import convert_to_openai_function -from src.actions.action import Action + from langchain.prompts import ( PromptTemplate, ChatPromptTemplate, @@ -17,18 +18,19 @@ from langchain_core.output_parsers import JsonOutputParser from langchain import hub from langchain.agents import AgentExecutor, create_openai_tools_agent +from src.actions.action import Action from src.common.definitions import ( ActionSuccess, - ActionRequest, - ActionRecord, FailureReason, - FeedbackMessage, - feedback_to_str, - record_to_str, - request_to_str, ) -from src.planning.state import RefactoringAgentState +from src.planning.state import ( + ActionRecord, + ActionRequest, + FeedbackMessage, + RefactoringAgentState, + state_to_str, +) from src.utilities.formatting import format_list @@ -73,20 +75,24 @@ def dispatch( ActionRecord: The result of the action execution. """ id = request["id"] - action_str = request["action_str"] + args = request["args"] action = self.actions.get(id) if action: try: - observation = action.execute(state, action_str) + observation = action.execute(state, args) return ActionRecord( request=request, result=observation, ) except Exception as e: - raise FeedbackMessage(FailureReason.ACTION_FAILED, str(e)) + raise FeedbackMessage( + FailureReason.ACTION_FAILED, str(e), request=request + ) else: raise FeedbackMessage( - FailureReason.ACTION_NOT_FOUND, f"Action {id} not found" + FailureReason.ACTION_NOT_FOUND, + f"Action {id} not found", + request=request, ) @@ -114,6 +120,16 @@ def __call__(self, state: RefactoringAgentState): return state +class ExecutePlan: + def __init__(self, action_list: List[Action]) -> None: + self.executor = ExecuteTopOfPlan(action_list) + + def __call__(self, state: RefactoringAgentState): + while len(state["plan"]) > 0: + state = self.executor(state) + return state + + # Given a prompt, the LLMControler will dispatch a suitable action @@ -140,51 +156,21 @@ def create_prompt(self): self.agent_prompt = prompt # For Context # TODO: Evaluate this part - message = f""" + message = f""" + '{self.current_task}' --- - -'{{goal}}' ---- - -{{history}} ---- - -{{plan}} ---- - -{{feedback}} ---- - -{{console}} ---- - -{{code_snippets}} +{{state}} --- Now invoke suitable functions to complete the Current Task. +Arguments for the functions should be constructed from the context provided, including from the output of past actions. Do not send other messages other than invoking functions. Invoke no more than {self.number_of_actions} function calls to complete the task. """ self.context_prompt = PromptTemplate.from_template(message) def format_context_prompt(self, state: RefactoringAgentState) -> str: - history = map(record_to_str, state["history"]) - plan = map(request_to_str, state["plan"]) - feedback = map(feedback_to_str, state["feedback"]) - - plan_str = format_list(plan, "P", "Plan") - history_str = format_list(history, "H", "History") - feedback_str = format_list(feedback, "F", "Feedback") - console_str = format_list(state["console"], "C", "Console") - code_str = format_list(state["code_snippets"], "S", "Code Snippets") - message_sent = self.context_prompt.format( - goal=state["goal"], - history=history_str, - plan=plan_str, - feedback=feedback_str, - console=console_str, - code_snippets=code_str, - ) + message_sent = self.context_prompt.format(state=state_to_str(state)) if self.verbose: print(message_sent) pass diff --git a/src/planning/plan_actions.py b/src/planning/plan_actions.py index a0f76e0..344f90f 100644 --- a/src/planning/plan_actions.py +++ b/src/planning/plan_actions.py @@ -1,6 +1,7 @@ import json from typing import List -from src.common.definitions import ActionRequest, FailureReason, FeedbackMessage +from src.actions.action import ActionRequest, FeedbackMessage +from src.common.definitions import FailureReason from src.execution import ActionDispatcher from src.planning.state import RefactoringAgentState from ..actions.action import Action @@ -33,9 +34,17 @@ def create_action_adder_for_plan(action: Action): def add_to_plan(state: RefactoringAgentState, args: AddToPlanInput): action_id = action.id - # Verify that the action's arguments are valid action_str = json.dumps(args.parameters) - request = ActionRequest(id=action_id, action_str=action_str) + # Verify that the action's arguments are valid + try: + p_args = action.cls(**args.parameters) + except Exception as e: + raise FeedbackMessage( + FailureReason.INVALID_ACTION_ARGS, + f"Invalid arguments for action {action_id}: {action_str}", + ) + + request = ActionRequest(id=action_id, args=p_args) state["plan"].append(request) return f"Added {action_id} with args {action_str} to plan" diff --git a/src/planning/planner.py b/src/planning/planner.py index fefcf00..1984759 100644 --- a/src/planning/planner.py +++ b/src/planning/planner.py @@ -1,7 +1,8 @@ from enum import Enum from typing import List -from src.actions.action import Action +from src.actions.action import Action, FeedbackMessage from src.actions.basic_actions import create_logging_action +from src.common.definitions import FailureReason from src.execution import ActionDispatcher, LLMController from src.planning.plan_actions import ( create_action_adder_for_plan, @@ -22,6 +23,8 @@ def __init__(self, action_list: List[Action]): task = """Select the next actions to add to the plan or clear the plan. Additional Notes: - You will be allowed to replan in the future so you can adjust your plan as you go. + - If your plan requires a result from an action that has not been executed yet, then stop planning. + - Do not let the plan fill up with garbage """ # TODO: Incorporate Saving thoughts self.controller = LLMController( @@ -38,48 +41,44 @@ class NextStep(Enum): EXECUTE = "execute" FINISH = "finish" + def __str__(self): + return self.value + class NextStepInput(BaseModel): - next_step: NextStep = Field(description="The next step to take") + should_continue: bool = Field( + description="Whether we should do another 'plan-executee' (true) loop or finish (false)." + ) -class DecisionMaker: +class ShouldContinue: next_node: str def __init__(self) -> None: - next_step_action = self._create_next_step_action() + should_continue_action = self._create_should_continue_action() task = """ - Select ONLY one of the following to do next: 'execute', 'plan' or 'finish'. - - 'execute': Run the next action on the top of 'Plan', i.e. #P1 - - 'plan': Adjust the contents of 'Plan' by adding or removing actions. - - 'finish': Ultimate goal is satisfied. No further actions are needed. - - Additional Notes: - - Call `transition_to_next_node` EXACTLY ONCE. - - You may need to adjust your plan as you go, especially if a result from a past action should be incorporated into a future action. - - A message is considered printed only if it appears under the `Console` section - - If the 'Ultimate Goal' is to have something printed or answered, then it must appear under the `Console` section. Otherwise, you will need to plan for it to be added. + Decide ONLY which branch to take next: + - Plan-Execute branch + - Finish branch + Do not return any other information """ - self.controller = LLMController([next_step_action], task) + self.controller = LLMController([should_continue_action], task) - def _create_next_step_action(self): - def transition_to_next_node(state: RefactoringAgentState, args: NextStepInput): - if args.next_step == NextStep.PLAN: + def _create_should_continue_action(self): + def should_continue(state: RefactoringAgentState, args: NextStepInput): + if args.should_continue: self.next_node = "planner" - elif args.next_step == NextStep.EXECUTE: - if state["plan"]: - self.next_node = "execute" - else: - self.next_node = "planner" # TODO: error state + return "Moving to planner step" else: self.next_node = "finish" - return f"Transitioning to next step. " + return "Moving to finish step" action = Action( - id="transition_to_next_step", - description="Transition to the next step", + id="should_continue", + description="""true = plan and execute, false = finish""", model_cls=NextStepInput, - f=transition_to_next_node, + # return_direct=True, + f=should_continue, ) return action diff --git a/src/planning/state.py b/src/planning/state.py index 557ba8c..cdd2a33 100644 --- a/src/planning/state.py +++ b/src/planning/state.py @@ -1,18 +1,58 @@ from ast import Tuple +from langchain_core.pydantic_v1 import BaseModel, Field + +from typing import Generic, List, Optional, TypeVar, TypedDict -from typing import List, TypedDict from src.common.definitions import ( - ActionRequest, CodeSnippet, - FeedbackMessage, + FailureReason, ProjectContext, - feedback_to_str, - record_to_str, - request_to_str, ) -from src.common.definitions import ActionRecord from src.utilities.formatting import format_list +ActionArgs = TypeVar("ActionArgs", bound=BaseModel) +ActionReturnType = TypeVar("ActionReturnType") + + +class ActionRequest(TypedDict, Generic[ActionArgs]): + id: str + args: ActionArgs + + +class FeedbackMessage(Exception, Generic[ActionArgs]): + def __init__( + self, + failure_reason: FailureReason, + message: str, + request: Optional[ActionRequest[ActionArgs]] = None, + ): + self.reason = failure_reason + self.message = message + self.request = request + super().__init__(message) + + +class ActionRecord(TypedDict, Generic[ActionArgs, ActionReturnType]): + request: ActionRequest[ActionArgs] + result: ActionReturnType + + +def request_to_str(request: ActionRequest) -> str: + return f"{{\"name\":{request['id']},\"parameters\":{request['args']}}}" + + +def record_to_str(record: ActionRecord) -> str: + # Get the name of the type of record.result + type_name = record["result"].__class__.__name__ + # check if request.result is a derived class of BaseModel + result_str = f"{type_name}({record['result']})" + + return f"{{\"request\":{request_to_str(record['request'])},\"result\":\"{result_str}\"}}" + + +def feedback_to_str(feedback: FeedbackMessage) -> str: + return f'{{"failure-reason":{feedback.reason.value},"message":{feedback.message},"request":{request_to_str(feedback.request) if feedback.request else "none"}}}' + class RefactoringAgentState(TypedDict): goal: str @@ -36,15 +76,21 @@ def state_to_str(state: RefactoringAgentState) -> str: console_str = format_list(state["console"], "C", "Console") code_str = format_list(state["code_snippets"], "S", "Code Snippets") return f"""Goal: -{state['goal']} -History + +'{state["goal"]}' +--- + {history_str} -Plan +--- + {plan_str} -Feedback +--- + {feedback_str} -Console +--- + {console_str} -Code +--- + {code_str} """ diff --git a/src/utilities/jedi_utils.py b/src/utilities/jedi_utils.py index 712aee5..fc969c4 100644 --- a/src/utilities/jedi_utils.py +++ b/src/utilities/jedi_utils.py @@ -10,7 +10,7 @@ def goto_symbol(context: ProjectContext, symbol: Symbol): - path = add_path_to_prefix(context.folder_path, symbol.file_location) + path = add_path_to_prefix(context.folder_path, symbol.file_path) line = symbol.line column = symbol.column jedi_script = jedi.Script(path=path) @@ -27,7 +27,7 @@ def symbol_to_definition(symbol: Symbol, context: ProjectContext) -> Definition: (end, _) = name.get_definition_end_position() span = CodeSpan( - file_location=path, + file_path=path, start_line=start, end_line=end, ) @@ -39,7 +39,7 @@ def jedi_name_to_symbol(name, context: ProjectContext) -> Symbol: path = remove_path_prefix(name.module_path, context.folder_path) result = Symbol( name=str(name.name), - file_location=str(path), + file_path=str(path), line=int(line), column=int(column), ) @@ -47,7 +47,7 @@ def jedi_name_to_symbol(name, context: ProjectContext) -> Symbol: def load_code(span: CodeSpan, context: ProjectContext): - path = add_path_to_prefix(context.folder_path, span.file_location) + path = add_path_to_prefix(context.folder_path, span.file_path) start = span.start_line end = span.end_line with open(path, "r") as file: @@ -59,7 +59,7 @@ def load_code(span: CodeSpan, context: ProjectContext): def span_to_snippet(span: CodeSpan, context: ProjectContext) -> CodeSnippet: return CodeSnippet( { - "file_path": span.file_location, + "file_path": span.file_path, "code": load_code(span, context), "starting_line": span.start_line, }