From 8b810752e69de143d4deb98aa72cc8ce1c18faf0 Mon Sep 17 00:00:00 2001
From: Dan Fuchs
Date: Wed, 29 May 2024 15:12:34 -0500
Subject: [PATCH] Notebook runner refreshes by recloning and reexecuting

---
Review notes (free-form text in this section, between the three-dash line
and the first "diff --git" line, is not part of the commit message):

* notebookrunner.py: startup() now delegates to the new initialize(), and
  shutdown() to the new cleanup().  The new refresh() logs, runs cleanup()
  followed by initialize(), then clears self.refreshing.  execute_code()
  checks self.refreshing at the top of every loop iteration and, when set,
  awaits refresh() and returns without running further notebooks.
* monkeyflocker/client.py and tests/handlers/flock_test.py: the refresh
  request changes from PUT /mobu/flocks/{name} to
  POST /mobu/flocks/{name}/refresh.
* tests/support/jupyter.py: the session name/type asserted by
  create_session() become the attributes expected_session_name and
  expected_session_type; their defaults equal the previously hard-coded
  values, so tests that do not set them are unaffected.
* tests/support/util.py: wait_for_log_message() polls the monkey log with
  range(1, 10), i.e. 9 attempts, sleeping 0.5 s before each one, and
  returns False if the text never appears.
* NOTE(review): cleanup() passes str(self._repo_dir) to shutil.rmtree
  unconditionally and then sets _repo_dir to None.  A second cleanup()
  without an intervening initialize() would hand rmtree the string "None".
  Confirm callers never do that.
* NOTE(review): refresh() resets self.refreshing only after initialize()
  succeeds; if the reclone raises, the flag appears to stay set.  Confirm
  that is the intended behavior.

 src/mobu/services/business/notebookrunner.py | 23 +++++-
 src/monkeyflocker/client.py                  |  4 +-
 tests/business/notebookrunner_test.py        | 84 +++++++++++++++++++-
 tests/handlers/flock_test.py                 |  4 +-
 tests/support/jupyter.py                     |  6 +-
 tests/support/util.py                        | 18 ++++-
 6 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/src/mobu/services/business/notebookrunner.py b/src/mobu/services/business/notebookrunner.py
index 3e7cb540..cfb24287 100644
--- a/src/mobu/services/business/notebookrunner.py
+++ b/src/mobu/services/business/notebookrunner.py
@@ -70,21 +70,34 @@ def annotations(self, cell_id: str | None = None) -> dict[str, str]:
         return result
 
     async def startup(self) -> None:
+        await self.initialize()
+        await super().startup()
+
+    async def cleanup(self) -> None:
+        shutil.rmtree(str(self._repo_dir))
+        self._repo_dir = None
+
+    async def initialize(self) -> None:
         if self._repo_dir is None:
             self._repo_dir = Path(TemporaryDirectory().name)
             await self.clone_repo()
+
         self._exclude_paths = {
             (self._repo_dir / path) for path in self.options.exclude_dirs
         }
         self._notebook_paths = self.find_notebooks()
         self.logger.info("Repository cloned and ready")
-        await super().startup()
 
     async def shutdown(self) -> None:
-        shutil.rmtree(str(self._repo_dir))
-        self._repo_dir = None
+        await self.cleanup()
         await super().shutdown()
 
+    async def refresh(self) -> None:
+        self.logger.info("Recloning notebooks and forcing new execution")
+        await self.cleanup()
+        await self.initialize()
+        self.refreshing = False
+
     async def clone_repo(self) -> None:
         url = self.options.repo_url
         branch = self.options.repo_branch
@@ -151,6 +164,10 @@ async def open_session(
 
     async def execute_code(self, session: JupyterLabSession) -> None:
         for count in range(self.options.max_executions):
+            if self.refreshing:
+                await self.refresh()
+                return
+
             self._notebook = self.next_notebook()
 
             iteration = f"{count + 1}/{self.options.max_executions}"
diff --git a/src/monkeyflocker/client.py b/src/monkeyflocker/client.py
index 3a752a04..5af68cd0 100644
--- a/src/monkeyflocker/client.py
+++ b/src/monkeyflocker/client.py
@@ -100,8 +100,8 @@ async def refresh(self, name: str) -> None:
         """Restart a flock of monkeys."""
         if not self._client:
             raise RuntimeError("Must be used as a context manager")
-        url = urljoin(self._base_url, f"/mobu/flocks/{name}")
-        r = await self._client.put(url)
+        url = urljoin(self._base_url, f"/mobu/flocks/{name}/refresh")
+        r = await self._client.post(url)
         r.raise_for_status()
 
     def _initialize_logging(self) -> BoundLogger:
diff --git a/tests/business/notebookrunner_test.py b/tests/business/notebookrunner_test.py
index fd0d93e3..61c4ba28 100644
--- a/tests/business/notebookrunner_test.py
+++ b/tests/business/notebookrunner_test.py
@@ -15,7 +15,8 @@
 from mobu.storage.git import Git
 
 from ..support.gafaelfawr import mock_gafaelfawr
-from ..support.util import wait_for_business
+from ..support.jupyter import MockJupyter
+from ..support.util import wait_for_business, wait_for_log_message
 
 
 async def setup_git_repo(repo_path: Path) -> None:
@@ -207,6 +208,87 @@ async def test_run_recursive(
     assert "Done with this cycle of notebooks" in r.text
 
 
+@pytest.mark.asyncio
+async def test_refresh(
+    client: AsyncClient,
+    jupyter: MockJupyter,
+    respx_mock: respx.Router,
+    tmp_path: Path,
+) -> None:
+    mock_gafaelfawr(respx_mock)
+    cwd = Path.cwd()
+
+    # Set up a notebook repository.
+    source_path = Path(__file__).parent.parent / "notebooks"
+    repo_path = tmp_path / "notebooks"
+
+    shutil.copytree(str(source_path), str(repo_path))
+
+    # Set up git repo
+    await setup_git_repo(repo_path)
+
+    # Start a monkey. We have to do this in a try/finally block since the
+    # runner will change working directories, which because working
+    # directories are process-global may mess up future tests.
+    try:
+        r = await client.put(
+            "/mobu/flocks",
+            json={
+                "name": "test",
+                "count": 1,
+                "user_spec": {"username_prefix": "testuser"},
+                "scopes": ["exec:notebook"],
+                "business": {
+                    "type": "NotebookRunner",
+                    "options": {
+                        "spawn_settle_time": 0,
+                        "execution_idle_time": 1,
+                        "idle_time": 1,
+                        "max_executions": 1000,
+                        "repo_url": str(repo_path),
+                        "repo_branch": "main",
+                        "working_directory": str(repo_path),
+                    },
+                },
+            },
+        )
+        assert r.status_code == 201
+
+        # We should see a message from the notebook execution in the logs.
+        assert await wait_for_log_message(
+            client, "testuser1", msg="This is a test"
+        )
+
+        # Change the notebook and git commit it
+        notebook = repo_path / "test-notebook.ipynb"
+        contents = notebook.read_text()
+        new_contents = contents.replace("This is a test", "This is a NEW test")
+        notebook.write_text(new_contents)
+
+        git = Git(repo=repo_path)
+        await git.add(str(notebook))
+        await git.commit("-m", "Updating notebook")
+
+        jupyter.expected_session_name = "test-notebook.ipynb"
+        jupyter.expected_session_type = "notebook"
+
+        # Refresh the notebook
+        r = await client.post("/mobu/flocks/test/refresh")
+        assert r.status_code == 202
+
+        # The refresh should have forced a new execution
+        assert await wait_for_log_message(
+            client, "testuser1", msg="Deleting lab"
+        )
+
+        # We should see a message from the updated notebook.
+        assert await wait_for_log_message(
+            client, "testuser1", msg="This is a NEW test"
+        )
+    finally:
+        os.chdir(cwd)
+
+
 @pytest.mark.asyncio
 async def test_exclude_dirs(
     client: AsyncClient, respx_mock: respx.Router, tmp_path: Path
diff --git a/tests/handlers/flock_test.py b/tests/handlers/flock_test.py
index 72b305a7..04014549 100644
--- a/tests/handlers/flock_test.py
+++ b/tests/handlers/flock_test.py
@@ -76,7 +76,7 @@ async def test_start_stop_refresh(
     assert r.status_code == 200
     assert r.json() == expected
 
-    r = await client.put("/mobu/flocks/test")
+    r = await client.post("/mobu/flocks/test/refresh")
     assert r.status_code == 202
     # That should've updated the refreshing status
     expected["monkeys"][0]["business"]["refreshing"] = True
@@ -114,7 +114,7 @@ async def test_start_stop_refresh(
 
     r = await client.get("/mobu/flocks/other")
     assert r.status_code == 404
-    r = await client.put("/mobu/flocks/other")
+    r = await client.post("/mobu/flocks/other/refresh")
     assert r.status_code == 404
     r = await client.delete("/mobu/flocks/other")
     assert r.status_code == 404
diff --git a/tests/support/jupyter.py b/tests/support/jupyter.py
index be762d5e..6e6c30a5 100644
--- a/tests/support/jupyter.py
+++ b/tests/support/jupyter.py
@@ -91,6 +91,8 @@ def __init__(self) -> None:
         self.spawn_timeout = False
         self.redirect_loop = False
         self.lab_form: dict[str, dict[str, str]] = {}
+        self.expected_session_name = "(no notebook)"
+        self.expected_session_type = "console"
         self._delete_at: dict[str, datetime | None] = {}
         self._fail: dict[str, dict[JupyterAction, bool]] = {}
         self._hub_xsrf = os.urandom(8).hex()
@@ -278,8 +280,8 @@ def create_session(self, request: Request) -> Response:
         assert state == JupyterState.LAB_RUNNING
         body = json.loads(request.content.decode())
         assert body["kernel"]["name"] == "LSST"
-        assert body["name"] == "(no notebook)"
-        assert body["type"] == "console"
+        assert body["name"] == self.expected_session_name
+        assert body["type"] == self.expected_session_type
         session = JupyterLabSession(
             session_id=uuid4().hex, kernel_id=uuid4().hex
         )
diff --git a/tests/support/util.py b/tests/support/util.py
index 313ba2d8..4c4f0729 100644
--- a/tests/support/util.py
+++ b/tests/support/util.py
@@ -7,7 +7,10 @@
 
 from httpx import AsyncClient
 
-__all__ = ["wait_for_business"]
+__all__ = [
+    "wait_for_business",
+    "wait_for_log_message",
+]
 
 
 async def wait_for_business(
@@ -26,6 +29,19 @@ async def wait_for_business(
     return data
 
 
+async def wait_for_log_message(
+    client: AsyncClient, username: str, *, flock: str = "test", msg: str
+) -> bool:
+    """Wait until some text appears in a user's log."""
+    for _ in range(1, 10):
+        await asyncio.sleep(0.5)
+        r = await client.get(f"/mobu/flocks/{flock}/monkeys/{username}/log")
+        assert r.status_code == 200
+        if msg in r.text:
+            return True
+    return False
+
+
 async def wait_for_flock_start(client: AsyncClient, flock: str) -> None:
     """Wait for all the monkeys in a flock to have started."""
     for _ in range(1, 10):