From 41ea6d3f636e3bc9712556077661d0179a9358e2 Mon Sep 17 00:00:00 2001 From: Grant <50287275+granawkins@users.noreply.github.com> Date: Fri, 19 Jan 2024 07:01:23 +0700 Subject: [PATCH] config.sampler to turn on/off sampler (#496) --- mentat/config.py | 2 +- mentat/sampler/sampler.py | 3 +-- mentat/session.py | 4 ++-- tests/sampler_test.py | 2 ++ 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mentat/config.py b/mentat/config.py index 574611d6e..b3666ed98 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -109,7 +109,7 @@ class Config: }, converter=converters.optional(converters.to_bool), ) - auto_save_snapshot: bool = attr.field( + sampler: bool = attr.field( default=False, metadata={ "description": ( diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py index c4216c35a..e5c7cd5e8 100644 --- a/mentat/sampler/sampler.py +++ b/mentat/sampler/sampler.py @@ -17,7 +17,6 @@ class Sampler: - active: bool = True diff_active: str | None = None commit_active: str | None = None last_sample_id: str | None = None @@ -43,7 +42,7 @@ def set_active_diff(self): f"Sampler error setting active diff: {e}. Disabling sampler.", style="error", ) - self.active = False + ctx.config.sampler = False async def create_sample(self) -> Sample: # Check for repo and merge_base in config diff --git a/mentat/session.py b/mentat/session.py index 5e7f66749..d9a444da8 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -117,7 +117,7 @@ def __init__( ): for file in code_context.diff_context.diff_files(): code_context.include(file) - if config.auto_save_snapshot: + if config.sampler: sampler.set_active_diff() def _create_task(self, coro: Coroutine[None, None, Any]): @@ -185,7 +185,7 @@ async def _main(self): await get_user_feedback_on_edits(file_edits) ) - if session_context.sampler and session_context.sampler.active: + if session_context.config.sampler: session_context.sampler.set_active_diff() applied_edits = await code_file_manager.write_changes_to_files( diff --git a/tests/sampler_test.py b/tests/sampler_test.py index 72f512e59..60af89d60 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -37,6 +37,7 @@ async def test_sample_from_context( mock_collect_user_input, ): mock_session_context.config.sample_repo = "test_sample_repo" + mock_session_context.config.sampler = True mocker.patch( "mentat.conversation.Conversation.get_messages", @@ -382,6 +383,7 @@ async def test_sampler_integration( # Generate a sample using Mentat python_client = PythonClient(cwd=temp_testbed, paths=["."]) await python_client.startup() + python_client.session.ctx.config.sampler = True await python_client.call_mentat_auto_accept(dedent("""\ Make the following changes to "multifile_calculator/operations.py": 1. Add "# Inserted line 2" as the first line