-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
260 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ langchain | |
langchainhub | ||
langchain_openai | ||
langchain_experimental | ||
langgraph | ||
|
||
python-dotenv | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
from .python_repl import PythonReplTool | ||
from .code_search import CodeSearchToolkit | ||
from .project_context import ProjectContext |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from sre_constants import SUCCESS | ||
from typing import TypeVar, Generic, Callable, Type, TypedDict | ||
from ..common import ProjectContext | ||
from pydantic import BaseModel | ||
from langchain_core.output_parsers import JsonOutputParser | ||
|
||
# Action is an abstract base class | ||
|
||
|
||
class Action(ABC): | ||
def __init__(self, id, description): | ||
self.id = id | ||
self.description = description | ||
|
||
@abstractmethod | ||
def execute(self, action_str: str, **kwargs) -> str: | ||
pass | ||
|
||
@abstractmethod | ||
def get_prompt_schema(self) -> str: | ||
pass | ||
|
||
|
||
class ActionSuccess(Enum): | ||
SUCCESS = "SUCCESS" | ||
ACTION_NOT_FOUND = "ACTION_NOT_FOUND" | ||
ACTION_FAILED = "ACTION_FAILED" | ||
|
||
|
||
class ActionRecord(TypedDict): | ||
id: str | ||
success: ActionSuccess | ||
action_str: str | ||
observation: str | ||
|
||
|
||
AA = TypeVar("AA", bound=BaseModel) | ||
# Ensure AA is a pedantic type | ||
|
||
|
||
class ProjectContextualisedAction(Action, Generic[AA]): | ||
def __init__( | ||
self, | ||
id, | ||
description, | ||
model_cls: Type[AA], | ||
f: Callable[[ProjectContext, AA], str], | ||
): | ||
super().__init__(id, description) | ||
self.parser = JsonOutputParser(pydantic_object=model_cls) | ||
self.f = f | ||
self.cls = model_cls | ||
|
||
def execute(self, action_str: str, **kwargs) -> str: | ||
context = kwargs.get("context") | ||
if not context or not isinstance(context, ProjectContext): | ||
raise ValueError("No context provided") | ||
args = self.parser.invoke(action_str) | ||
return self.f(context, args) | ||
|
||
def get_prompt_schema(self) -> str: | ||
return str(self.cls.model_json_schema()) | ||
|
||
|
||
class ActionDispatcher: | ||
def __init__(self): | ||
self.actions = {} | ||
|
||
def register_action(self, action: Action): | ||
self.actions[action.id] = action | ||
|
||
def dispatch(self, id: str, action_str: str, **kwargs) -> ActionRecord: | ||
action = self.actions.get(id) | ||
if action: | ||
try: | ||
observation = action.execute(action_str, **kwargs) | ||
return ActionRecord( | ||
id=id, | ||
success=ActionSuccess.SUCCESS, | ||
action_str=action_str, | ||
observation=observation, | ||
) | ||
except Exception as e: | ||
return ActionRecord( | ||
id=id, | ||
success=ActionSuccess.ACTION_FAILED, | ||
action_str=action_str, | ||
observation=str(e), | ||
) | ||
else: | ||
return ActionRecord( | ||
id=id, | ||
success=ActionSuccess.ACTION_NOT_FOUND, | ||
action_str=action_str, | ||
observation="", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from action import ActionDispatcher, ProjectContextualisedAction | ||
from pydantic import BaseModel, Field | ||
|
||
######################### | ||
# Logging Action | ||
|
||
|
||
class LoggingInput(BaseModel): | ||
message: str = Field(description="The message to log") | ||
|
||
|
||
def create_logging_action(): | ||
def log(context, args): | ||
print(args.message) | ||
return "Logged message" | ||
|
||
action = ProjectContextualisedAction( | ||
id="log", description="Log a message", model_cls=LoggingInput, f=log | ||
) | ||
return action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Type, Optional, List | ||
from ..common import ProjectContext, Symbol, parse_completion_to_symbol | ||
|
||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.tools import BaseTool, StructuredTool, tool | ||
from langchain_core.runnables import RunnableBinding | ||
from langchain.callbacks.manager import ( | ||
AsyncCallbackManagerForToolRun, | ||
CallbackManagerForToolRun, | ||
) | ||
import jedi | ||
import os | ||
|
||
|
||
###################################### | ||
# JEDI Utils | ||
def get_definition_for_name(file_path, start_line, end_line): | ||
# Load the file | ||
with open(file_path, "r") as file: | ||
code = file.readlines() | ||
# Get the code | ||
return "\n".join(code[start_line:end_line]) | ||
|
||
|
||
########################################### | ||
# Tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,60 @@ | ||
from typing import Type, Optional | ||
|
||
from .project_context import ProjectContext | ||
from typing import Optional, List | ||
from ..common import ProjectContext, Symbol, parse_completion_to_symbol | ||
|
||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.tools import BaseTool, StructuredTool, tool | ||
from langchain.callbacks.manager import ( | ||
AsyncCallbackManagerForToolRun, | ||
CallbackManagerForToolRun, | ||
) | ||
from langchain.tools import tool | ||
|
||
import jedi | ||
import os | ||
|
||
###################################### | ||
# JEDI Utils | ||
def get_definition_for_name(file_path, start_line,end_line): | ||
# Load the file | ||
with open(file_path, "r") as file: | ||
code = file.readlines() | ||
# Get the code | ||
return "\n".join(code[start_line:end_line]) | ||
|
||
|
||
########################################### | ||
# Tools | ||
class SearchInput(BaseModel): | ||
query: str = Field(description="should be a search query for jedi") | ||
file_path: str = Field(description="should be a file path") | ||
|
||
|
||
def create_script_search_tool(project_context: ProjectContext): | ||
def search(query:str, file_path:str): | ||
# folder path / file path | ||
path = os.path.join(project_context["folder_path"], file_path) | ||
script = jedi.Script(path=path) | ||
completions = list(script.complete_search(query)) | ||
print(completions) | ||
result = [] | ||
for completion in completions: | ||
signatures = completion.get_signatures() | ||
start = completion.get_definition_start_position() | ||
end = completion.get_definition_end_position() | ||
body = get_definition_for_name(path,start[0],end[0]) | ||
info = { | ||
"name": completion.name, | ||
"type": completion.type, | ||
"signatures": [sig.description for sig in signatures], | ||
"body": body, | ||
} | ||
result.append(str(info)) | ||
|
||
|
||
|
||
return "\n".join(result) | ||
|
||
return StructuredTool.from_function( | ||
func=search, | ||
name="script-search", | ||
description="Search a file for symbols using jedi. Returns the definition of the symbol.", | ||
args_schema=SearchInput, | ||
class SearchInput(BaseModel): | ||
query: str = Field(description="a symbol to search for in repository.") | ||
fuzzy: bool = Field(description="whether to use fuzzy search", default=False) | ||
file_path: Optional[str] = Field( | ||
description="whether to narrow the search to a specific file. If not provided, search the entire repository.", | ||
default=None, | ||
) | ||
|
||
|
||
|
||
def create_code_search(context: ProjectContext): | ||
@tool("code-symbol-search", args_schema=SearchInput) | ||
def code_search( | ||
query: str, | ||
fuzzy: bool = False, | ||
file_path: Optional[str] = None, | ||
) -> List[Symbol]: | ||
""" | ||
Performs a search for a symbol in a file or folder. | ||
""" | ||
# searcher | ||
if file_path is None: | ||
# folder path | ||
searcher = jedi.Project(context.folder_path) | ||
else: | ||
# folder path / file path | ||
path = os.path.join(context.folder_path, file_path) | ||
searcher = jedi.Script(path=path) | ||
|
||
completions = list(searcher.complete_search(query, fuzzy=fuzzy)) | ||
|
||
print(completions) | ||
return [parse_completion_to_symbol(completion) for completion in completions] | ||
|
||
return code_search | ||
|
||
|
||
# Create a search toolkita | ||
|
||
|
||
class CodeSearchToolkit: | ||
def __init__(self,context:ProjectContext) -> None: | ||
|
||
def __init__(self, context: ProjectContext): | ||
self.context = context | ||
|
||
def get_tools(self): | ||
return [create_script_search_tool(self.context)] | ||
return [create_code_search(self.context)] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,39 @@ | ||
from langchain import hub | ||
from langchain.agents import AgentExecutor, create_openai_functions_agent | ||
from langchain_openai import ChatOpenAI | ||
from .actions import PythonReplTool, CodeSearchToolkit, ProjectContext | ||
from langgraph.graph import StateGraph | ||
from src.planning.memory import History | ||
from .actions import PythonReplTool, CodeSearchToolkit | ||
from .common.definitions import ProjectContext | ||
from .actions.action import ActionDispatcher | ||
from .actions.basic_actions import create_logging_action | ||
from typing import TypedDict | ||
|
||
|
||
def test_agent(query: str,repo_path: str): | ||
context:ProjectContext = { | ||
"folder_path": repo_path | ||
} | ||
tools = [PythonReplTool, *CodeSearchToolkit(context).get_tools()] | ||
# Get the prompt to use - you can modify this! | ||
prompt = hub.pull("hwchase17/openai-functions-agent") | ||
# Choose the LLM that will drive the agent | ||
llm = ChatOpenAI(model="gpt-4-turbo-preview") | ||
# Construct the OpenAI Functions agent | ||
agent_runnable = create_openai_functions_agent(llm, tools, prompt) | ||
class RefactoringAgentState(TypedDict): | ||
history: History | ||
|
||
# Create the agent executor | ||
agent_executor = AgentExecutor(agent=agent_runnable, tools=tools,verbose=True) | ||
|
||
# Run the agent | ||
response: str = agent_executor.invoke({"input": query}) | ||
return response | ||
class RefactoringAgent: | ||
def __init__(self, context: ProjectContext): | ||
self.context = context | ||
# Load Actions | ||
self.dispatcher = ActionDispatcher() | ||
self.dispatcher.register_action(create_logging_action()) | ||
|
||
self._setup_agent_graph() | ||
|
||
@staticmethod | ||
def _should_continue(state: RefactoringAgentState): | ||
return False | ||
|
||
def _setup_agent_graph(self): | ||
self.graph = StateGraph(RefactoringAgentState) | ||
|
||
self.llm = ChatOpenAI(model="gpt-4-turbo-preview") | ||
|
||
#self.graph.add_node('') | ||
self.app = self.graph.compile() | ||
|
||
def run(self,input:str) | ||
return self.app.invoke(input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .definitions import * |
Oops, something went wrong.