Commit: WIP
A-F-V committed Feb 18, 2024
1 parent c713b7a commit bc98808
Showing 13 changed files with 260 additions and 81 deletions.
5 changes: 3 additions & 2 deletions .vscode/settings.json
@@ -3,7 +3,7 @@
"python.analysis.typeCheckingMode": "basic",
"[python]": {
// Style Formatting
"editor.defaultFormatter": "ms-python.python",
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
},
// Pylint Settings
@@ -24,5 +24,6 @@
"python.testing.pytestEnabled": true,
"python.analysis.diagnosticSeverityOverrides": {
"reportUnknownMemberType": "none"
}
},
"python.analysis.autoImportCompletions": true
}
8 changes: 6 additions & 2 deletions cli.py
@@ -2,7 +2,8 @@
import os
import dotenv
import pygit2 as git
from src.agent import test_agent
from src.agent import RefactoringAgent, test_agent
from src.common.definitions import ProjectContext

dotenv.load_dotenv()

@@ -40,7 +41,10 @@ def run(repo: str):

code_path = "edit_distance/edit_distance.py"
query = f"Get the docstring of a function starting with `lowest` in {code_path}. Return only that"
click.echo(test_agent(query,repo))

context = ProjectContext(folder_path=repo)
agent = RefactoringAgent(context)
click.echo(agent.run(query))


cli.add_command(init_repo)
1 change: 1 addition & 0 deletions requirements.txt
@@ -22,6 +22,7 @@ langchain
langchainhub
langchain_openai
langchain_experimental
langgraph

python-dotenv

1 change: 0 additions & 1 deletion src/actions/__init__.py
@@ -1,3 +1,2 @@
from .python_repl import PythonReplTool
from .code_search import CodeSearchToolkit
from .project_context import ProjectContext
98 changes: 98 additions & 0 deletions src/actions/action.py
@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar, Generic, Callable, Type, TypedDict
from ..common import ProjectContext
from pydantic import BaseModel
from langchain_core.output_parsers import JsonOutputParser

# Action is an abstract base class


class Action(ABC):
def __init__(self, id, description):
self.id = id
self.description = description

@abstractmethod
def execute(self, action_str: str, **kwargs) -> str:
pass

@abstractmethod
def get_prompt_schema(self) -> str:
pass


class ActionSuccess(Enum):
SUCCESS = "SUCCESS"
ACTION_NOT_FOUND = "ACTION_NOT_FOUND"
ACTION_FAILED = "ACTION_FAILED"


class ActionRecord(TypedDict):
id: str
success: ActionSuccess
action_str: str
observation: str


AA = TypeVar("AA", bound=BaseModel)
# Ensure AA is a pydantic model type


class ProjectContextualisedAction(Action, Generic[AA]):
def __init__(
self,
id,
description,
model_cls: Type[AA],
f: Callable[[ProjectContext, AA], str],
):
super().__init__(id, description)
self.parser = JsonOutputParser(pydantic_object=model_cls)
self.f = f
self.cls = model_cls

def execute(self, action_str: str, **kwargs) -> str:
context = kwargs.get("context")
if not context or not isinstance(context, ProjectContext):
raise ValueError("No context provided")
args = self.parser.invoke(action_str)
return self.f(context, args)

def get_prompt_schema(self) -> str:
return str(self.cls.model_json_schema())


class ActionDispatcher:
def __init__(self):
self.actions = {}

def register_action(self, action: Action):
self.actions[action.id] = action

def dispatch(self, id: str, action_str: str, **kwargs) -> ActionRecord:
action = self.actions.get(id)
if action:
try:
observation = action.execute(action_str, **kwargs)
return ActionRecord(
id=id,
success=ActionSuccess.SUCCESS,
action_str=action_str,
observation=observation,
)
except Exception as e:
return ActionRecord(
id=id,
success=ActionSuccess.ACTION_FAILED,
action_str=action_str,
observation=str(e),
)
else:
return ActionRecord(
id=id,
success=ActionSuccess.ACTION_NOT_FOUND,
action_str=action_str,
observation="",
)
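Taken together, these classes define a small dispatch layer: an action parses a JSON argument string with JsonOutputParser, runs against a ProjectContext, and always comes back wrapped in an ActionRecord. The following is a minimal usage sketch and is not part of the commit: the `read_file` action, its input model, and the `.` project root are hypothetical, and it assumes the module is importable as `src.actions.action` and that ProjectContext accepts a `folder_path` keyword as in cli.py.

import os

from pydantic import BaseModel, Field

from src.actions.action import (
    ActionDispatcher,
    ActionSuccess,
    ProjectContextualisedAction,
)
from src.common.definitions import ProjectContext


class ReadFileInput(BaseModel):
    """Hypothetical input model, for illustration only."""

    path: str = Field(description="file path relative to the project root")


def read_file(context: ProjectContext, args: dict) -> str:
    # JsonOutputParser produces a plain dict, so the arguments are indexed by key.
    with open(os.path.join(context.folder_path, args["path"]), "r") as f:
        return f.read()


dispatcher = ActionDispatcher()
dispatcher.register_action(
    ProjectContextualisedAction(
        id="read_file",
        description="Read a file from the project",
        model_cls=ReadFileInput,
        f=read_file,
    )
)

context = ProjectContext(folder_path=".")  # placeholder project root
record = dispatcher.dispatch("read_file", '{"path": "README.md"}', context=context)
print(record["success"], record["observation"])

# Unknown action ids are reported in the record rather than raised.
missing = dispatcher.dispatch("does_not_exist", "{}", context=context)
assert missing["success"] is ActionSuccess.ACTION_NOT_FOUND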
20 changes: 20 additions & 0 deletions src/actions/basic_actions.py
@@ -0,0 +1,20 @@
from .action import ActionDispatcher, ProjectContextualisedAction
from pydantic import BaseModel, Field

#########################
# Logging Action


class LoggingInput(BaseModel):
message: str = Field(description="The message to log")


def create_logging_action():
def log(context, args):
# args is the dict produced by JsonOutputParser, so index by key
print(args["message"])
return "Logged message"

action = ProjectContextualisedAction(
id="log", description="Log a message", model_cls=LoggingInput, f=log
)
return action
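A quick sketch of wiring this action into the dispatcher, not part of the commit; it assumes the relative import and dict indexing above, and the message payload is made up.

from src.actions.action import ActionDispatcher
from src.actions.basic_actions import create_logging_action
from src.common.definitions import ProjectContext

dispatcher = ActionDispatcher()
dispatcher.register_action(create_logging_action())

context = ProjectContext(folder_path=".")  # placeholder project root
record = dispatcher.dispatch("log", '{"message": "hello from the agent"}', context=context)
print(record["success"], record["observation"])  # expected: SUCCESS, "Logged message"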
26 changes: 26 additions & 0 deletions src/actions/code_inspection.py
@@ -0,0 +1,26 @@
from typing import Type, Optional, List
from ..common import ProjectContext, Symbol, parse_completion_to_symbol

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool
from langchain_core.runnables import RunnableBinding
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
import jedi
import os


######################################
# JEDI Utils
def get_definition_for_name(file_path, start_line, end_line):
# Load the file
with open(file_path, "r") as file:
code = file.readlines()
# Get the code
return "\n".join(code[start_line:end_line])


###########################################
# Tools
93 changes: 39 additions & 54 deletions src/actions/code_search.py
@@ -1,75 +1,60 @@
from typing import Type, Optional

from .project_context import ProjectContext
from typing import Optional, List
from ..common import ProjectContext, Symbol, parse_completion_to_symbol

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools import tool

import jedi
import os

######################################
# JEDI Utils
def get_definition_for_name(file_path, start_line,end_line):
# Load the file
with open(file_path, "r") as file:
code = file.readlines()
# Get the code
return "\n".join(code[start_line:end_line])


###########################################
# Tools
class SearchInput(BaseModel):
query: str = Field(description="should be a search query for jedi")
file_path: str = Field(description="should be a file path")


def create_script_search_tool(project_context: ProjectContext):
def search(query:str, file_path:str):
# folder path / file path
path = os.path.join(project_context["folder_path"], file_path)
script = jedi.Script(path=path)
completions = list(script.complete_search(query))
print(completions)
result = []
for completion in completions:
signatures = completion.get_signatures()
start = completion.get_definition_start_position()
end = completion.get_definition_end_position()
body = get_definition_for_name(path,start[0],end[0])
info = {
"name": completion.name,
"type": completion.type,
"signatures": [sig.description for sig in signatures],
"body": body,
}
result.append(str(info))



return "\n".join(result)

return StructuredTool.from_function(
func=search,
name="script-search",
description="Search a file for symbols using jedi. Returns the definition of the symbol.",
args_schema=SearchInput,
class SearchInput(BaseModel):
query: str = Field(description="a symbol to search for in repository.")
fuzzy: bool = Field(description="whether to use fuzzy search", default=False)
file_path: Optional[str] = Field(
description="whether to narrow the search to a specific file. If not provided, search the entire repository.",
default=None,
)



def create_code_search(context: ProjectContext):
@tool("code-symbol-search", args_schema=SearchInput)
def code_search(
query: str,
fuzzy: bool = False,
file_path: Optional[str] = None,
) -> List[Symbol]:
"""
Performs a search for a symbol in a file or folder.
"""
# searcher
if file_path is None:
# folder path
searcher = jedi.Project(context.folder_path)
else:
# folder path / file path
path = os.path.join(context.folder_path, file_path)
searcher = jedi.Script(path=path)

completions = list(searcher.complete_search(query, fuzzy=fuzzy))

print(completions)
return [parse_completion_to_symbol(completion) for completion in completions]

return code_search


# Create a search toolkit


class CodeSearchToolkit:
def __init__(self,context:ProjectContext) -> None:

def __init__(self, context: ProjectContext):
self.context = context

def get_tools(self):
return [create_script_search_tool(self.context)]
return [create_code_search(self.context)]
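Because `code_search` is created with LangChain's `@tool` decorator, it is a runnable tool that can be invoked directly with a dict of arguments. A hedged usage sketch follows; the repository path is a placeholder, while the query and file path reuse the ones from cli.py.

from src.actions.code_search import create_code_search
from src.common.definitions import ProjectContext

context = ProjectContext(folder_path="path/to/some/repo")  # placeholder path
code_search = create_code_search(context)

# Repository-wide fuzzy search for a symbol name.
symbols = code_search.invoke({"query": "lowest", "fuzzy": True})

# Narrowed to a single file.
symbols_in_file = code_search.invoke(
    {"query": "lowest", "file_path": "edit_distance/edit_distance.py"}
)
print(symbols, symbols_in_file)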
5 changes: 0 additions & 5 deletions src/actions/project_context.py

This file was deleted.

49 changes: 32 additions & 17 deletions src/agent.py
@@ -1,24 +1,39 @@
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_openai import ChatOpenAI
from .actions import PythonReplTool, CodeSearchToolkit, ProjectContext
from langgraph.graph import StateGraph
from src.planning.memory import History
from .actions import PythonReplTool, CodeSearchToolkit
from .common.definitions import ProjectContext
from .actions.action import ActionDispatcher
from .actions.basic_actions import create_logging_action
from typing import TypedDict


def test_agent(query: str,repo_path: str):
context:ProjectContext = {
"folder_path": repo_path
}
tools = [PythonReplTool, *CodeSearchToolkit(context).get_tools()]
# Get the prompt to use - you can modify this!
prompt = hub.pull("hwchase17/openai-functions-agent")
# Choose the LLM that will drive the agent
llm = ChatOpenAI(model="gpt-4-turbo-preview")
# Construct the OpenAI Functions agent
agent_runnable = create_openai_functions_agent(llm, tools, prompt)
class RefactoringAgentState(TypedDict):
history: History

# Create the agent executor
agent_executor = AgentExecutor(agent=agent_runnable, tools=tools,verbose=True)

# Run the agent
response: str = agent_executor.invoke({"input": query})
return response
class RefactoringAgent:
def __init__(self, context: ProjectContext):
self.context = context
# Load Actions
self.dispatcher = ActionDispatcher()
self.dispatcher.register_action(create_logging_action())

self._setup_agent_graph()

@staticmethod
def _should_continue(state: RefactoringAgentState):
return False

def _setup_agent_graph(self):
self.graph = StateGraph(RefactoringAgentState)

self.llm = ChatOpenAI(model="gpt-4-turbo-preview")

#self.graph.add_node('')
self.app = self.graph.compile()

def run(self, input: str):
return self.app.invoke(input)
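`_setup_agent_graph` is still a stub: no nodes are added before `compile()`, so `run` cannot work yet. Purely as a speculative sketch of the direction, and not the author's design, a LangGraph graph over RefactoringAgentState might be wired roughly as below, with a hypothetical `step` node and `_should_continue` as the loop condition.

from langgraph.graph import StateGraph, END

from src.agent import RefactoringAgent, RefactoringAgentState


def step(state: RefactoringAgentState) -> RefactoringAgentState:
    # Placeholder node: a real node would ask the LLM for the next action,
    # dispatch it, and append the resulting ActionRecord to the history.
    return state


graph = StateGraph(RefactoringAgentState)
graph.add_node("step", step)
graph.set_entry_point("step")
graph.add_conditional_edges(
    "step",
    RefactoringAgent._should_continue,  # currently always False, so a single pass
    {True: "step", False: END},
)
app = graph.compile()
final_state = app.invoke({"history": []})  # assumes History can start out empty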
1 change: 1 addition & 0 deletions src/common/__init__.py
@@ -0,0 +1 @@
from .definitions import *