Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sweep: Add tests for context agent #3646

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions sweepai/core/context_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,19 @@ def escape_ripgrep(text):
text = text.replace(s, "\\" + s)
return text

def run_ripgrep_command(code_entity, repo_dir):
    """Search ``repo_dir`` with ripgrep and return its raw standard output.

    Runs ``rg -n -i`` (line numbers, case-insensitive) without going through
    a shell, so shell metacharacters in either argument are never executed.

    Args:
        code_entity: Search pattern. It may arrive wrapped in shell-style
            quotes (e.g. ``'"two words"'``); the quoting is removed with
            ``shlex.split`` so the quoted text reaches ripgrep as a single
            argument.
        repo_dir: Directory to search. Passed as one argument, so paths
            containing spaces are not split.

    Returns:
        ripgrep's standard output as text. An empty string is returned when
        the pattern is empty, its quotes are unbalanced, or ``rg`` cannot be
        started.
    """
    import shlex  # function-scope import keeps this block self-contained

    try:
        pattern_args = shlex.split(code_entity or "")
    except ValueError:
        # Unbalanced quotes: the former shell invocation could not parse such
        # a command either and produced no output.
        return ""
    if not pattern_args:
        # Without a pattern ripgrep would treat repo_dir as the pattern.
        return ""

    # "--" ends option parsing so a pattern starting with "-" is searched
    # for instead of being interpreted as a ripgrep flag.
    rg_command = ["rg", "-n", "-i", "--", *pattern_args, str(repo_dir)]
    try:
        result = subprocess.run(rg_command, text=True, capture_output=True)
    except OSError:
        # rg missing or not executable: the shell-based call yielded empty
        # stdout in that situation, so keep returning an empty string.
        return ""
    return result.stdout

@staticmethod
def can_add_snippet(snippet: Snippet, current_snippets: list[Snippet]):
return (
Expand Down Expand Up @@ -752,18 +765,8 @@ def handle_function_call(
if function_name == "code_search":
code_entity = f'"{function_input["code_entity"]}"' # handles cases with two words
code_entity = escape_ripgrep(code_entity) # escape special characters
rg_command = [
"rg",
"-n",
"-i",
code_entity,
repo_context_manager.cloned_repo.repo_dir,
]
try:
result = subprocess.run(
" ".join(rg_command), text=True, shell=True, capture_output=True
)
rg_output = result.stdout
rg_output = run_ripgrep_command(code_entity, repo_context_manager.cloned_repo.repo_dir)
if rg_output:
# post process rip grep output to be more condensed
rg_output_pretty, file_output_dict, file_to_num_occurrences = post_process_rg_output(
Expand Down
50 changes: 50 additions & 0 deletions tests/test_context_pruning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
from sweepai.core.context_pruning import (
build_full_hierarchy,
load_graph_from_file,
RepoContextManager,
get_relevant_context,
)
import networkx as nx

class TestContextPruning(unittest.TestCase):
    """Tests for helpers imported from ``sweepai.core.context_pruning``."""

    def test_build_full_hierarchy(self):
        """Compare ``build_full_hierarchy`` output with an expected tree string."""
        G = nx.DiGraph()
        G.add_edge("main.py", "database.py")
        G.add_edge("database.py", "models.py")
        G.add_edge("utils.py", "models.py")
        hierarchy = build_full_hierarchy(G, "main.py", 2)
        expected_hierarchy = """main.py
├── database.py
│ └── models.py
└── utils.py
└── models.py
"""
        self.assertEqual(hierarchy, expected_hierarchy)

    def test_load_graph_from_file(self):
        """Load the fixture graph and check its type and node/edge counts."""
        graph = load_graph_from_file("tests/test_import_tree.txt")
        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.nodes), 5)
        self.assertEqual(len(graph.edges), 4)

    def test_get_relevant_context(self):
        """Check that ``get_relevant_context`` returns snippets mentioning client.py.

        NOTE(review): this calls the real ``get_relevant_context`` on a
        ``ClonedRepo`` for "sweepai/sweep"; it presumably needs repository
        and model access to pass -- confirm before running it in CI.
        """
        # ClonedRepo is not among this module's top-level imports, so using
        # it without this import raises NameError.
        from sweepai.utils.github_utils import ClonedRepo

        cloned_repo = ClonedRepo("sweepai/sweep", "123", "main")
        repo_context_manager = RepoContextManager(
            dir_obj=None,
            current_top_tree="",
            snippets=[],
            snippet_scores={},
            cloned_repo=cloned_repo,
        )
        query = "allow 'sweep.yaml' to be read from the user/organization's .github repository. this is found in client.py and we need to change this to optionally read from .github/sweep.yaml if it exists there"
        rcm = get_relevant_context(
            query,
            repo_context_manager,
            seed=42,
            ticket_progress=None,
            chat_logger=None,
        )
        self.assertIsInstance(rcm, RepoContextManager)
        # assertGreater reports the actual length when the check fails.
        self.assertGreater(len(rcm.current_top_snippets), 0)
        self.assertTrue(
            any(
                "client.py" in snippet.file_path
                for snippet in rcm.current_top_snippets
            )
        )
Loading