Skip to content

Commit

Permalink
config.sampler to turn on/off sampler (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
granawkins authored Jan 19, 2024
1 parent 62ba2a6 commit 41ea6d3
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Config:
},
converter=converters.optional(converters.to_bool),
)
auto_save_snapshot: bool = attr.field(
sampler: bool = attr.field(
default=False,
metadata={
"description": (
Expand Down
3 changes: 1 addition & 2 deletions mentat/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


class Sampler:
active: bool = True
diff_active: str | None = None
commit_active: str | None = None
last_sample_id: str | None = None
Expand All @@ -43,7 +42,7 @@ def set_active_diff(self):
f"Sampler error setting active diff: {e}. Disabling sampler.",
style="error",
)
self.active = False
ctx.config.sampler = False

async def create_sample(self) -> Sample:
# Check for repo and merge_base in config
Expand Down
4 changes: 2 additions & 2 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
):
for file in code_context.diff_context.diff_files():
code_context.include(file)
if config.auto_save_snapshot:
if config.sampler:
sampler.set_active_diff()

def _create_task(self, coro: Coroutine[None, None, Any]):
Expand Down Expand Up @@ -185,7 +185,7 @@ async def _main(self):
await get_user_feedback_on_edits(file_edits)
)

if session_context.sampler and session_context.sampler.active:
if session_context.config.sampler:
session_context.sampler.set_active_diff()

applied_edits = await code_file_manager.write_changes_to_files(
Expand Down
2 changes: 2 additions & 0 deletions tests/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def test_sample_from_context(
mock_collect_user_input,
):
mock_session_context.config.sample_repo = "test_sample_repo"
mock_session_context.config.sampler = True

mocker.patch(
"mentat.conversation.Conversation.get_messages",
Expand Down Expand Up @@ -382,6 +383,7 @@ async def test_sampler_integration(
# Generate a sample using Mentat
python_client = PythonClient(cwd=temp_testbed, paths=["."])
await python_client.startup()
python_client.session.ctx.config.sampler = True
await python_client.call_mentat_auto_accept(dedent("""\
Make the following changes to "multifile_calculator/operations.py":
1. Add "# Inserted line 2" as the first line
Expand Down

0 comments on commit 41ea6d3

Please sign in to comment.