Skip to content

Commit

Permalink
Search Tool start
Browse files Browse the repository at this point in the history
  • Loading branch information
A-F-V committed Feb 17, 2024
1 parent d8d48f9 commit c713b7a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 9 deletions.
8 changes: 6 additions & 2 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ def init_repo(repo: str, folder: str):


@click.command()
def run():
@click.option("--repo", default="./demo", help="The repo to download")
def run(repo: str):

click.echo("Running the demo")
click.echo(test_agent("What is the square root of 110"))

code_path = "edit_distance/edit_distance.py"
query = f"Get the docstring of a function starting with `lowest` in {code_path}. Return only that"
click.echo(test_agent(query,repo))


cli.add_command(init_repo)
Expand Down
4 changes: 3 additions & 1 deletion src/actions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .python_repl import PythonReplTool
from .python_repl import PythonReplTool
from .code_search import CodeSearchToolkit
from .project_context import ProjectContext
75 changes: 75 additions & 0 deletions src/actions/code_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Type, Optional

from .project_context import ProjectContext

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
import jedi
import os

######################################
# JEDI Utils
def get_definition_for_name(file_path, start_line,end_line):
# Load the file
with open(file_path, "r") as file:
code = file.readlines()
# Get the code
return "\n".join(code[start_line:end_line])


###########################################
# Tools
class SearchInput(BaseModel):
query: str = Field(description="should be a search query for jedi")
file_path: str = Field(description="should be a file path")


def create_script_search_tool(project_context: ProjectContext):
def search(query:str, file_path:str):
# folder path / file path
path = os.path.join(project_context["folder_path"], file_path)
script = jedi.Script(path=path)
completions = list(script.complete_search(query))
print(completions)
result = []
for completion in completions:
signatures = completion.get_signatures()
start = completion.get_definition_start_position()
end = completion.get_definition_end_position()
body = get_definition_for_name(path,start[0],end[0])
info = {
"name": completion.name,
"type": completion.type,
"signatures": [sig.description for sig in signatures],
"body": body,
}
result.append(str(info))



return "\n".join(result)

return StructuredTool.from_function(
func=search,
name="script-search",
description="Search a file for symbols using jedi. Returns the definition of the symbol.",
args_schema=SearchInput,
)






# Create a search toolkita

class CodeSearchToolkit:
def __init__(self,context:ProjectContext) -> None:
self.context = context

def get_tools(self):
return [create_script_search_tool(self.context)]
5 changes: 5 additions & 0 deletions src/actions/project_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import TypedDict

class ProjectContext(TypedDict):
"""A project context."""
folder_path: str
13 changes: 7 additions & 6 deletions src/agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_openai import ChatOpenAI
from .actions import PythonReplTool
from .actions import PythonReplTool, CodeSearchToolkit, ProjectContext



def test_agent(input: str):

tools = [PythonReplTool]
def test_agent(query: str,repo_path: str):
context:ProjectContext = {
"folder_path": repo_path
}
tools = [PythonReplTool, *CodeSearchToolkit(context).get_tools()]
# Get the prompt to use - you can modify this!
prompt = hub.pull("hwchase17/openai-functions-agent")
# Choose the LLM that will drive the agent
Expand All @@ -19,5 +20,5 @@ def test_agent(input: str):
agent_executor = AgentExecutor(agent=agent_runnable, tools=tools,verbose=True)

# Run the agent
response: str = agent_executor.invoke({"input": input})
response: str = agent_executor.invoke({"input": query})
return response

0 comments on commit c713b7a

Please sign in to comment.