From 21923fd9f338f4fa7e80459e831fa4276780627f Mon Sep 17 00:00:00 2001 From: octodog Date: Sun, 14 Jul 2024 00:14:49 +0900 Subject: [PATCH] fix: Run batch execution after session starts (#2327) (#2440) Co-authored-by: Sanghun Lee Co-authored-by: Joongi Kim --- changes/2327.fix.md | 1 + src/ai/backend/agent/agent.py | 24 +++++++-------- src/ai/backend/agent/server.py | 15 ++++++++++ src/ai/backend/manager/models/session.py | 1 + src/ai/backend/manager/registry.py | 38 ++++++++++++++++++++++-- 5 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 changes/2327.fix.md diff --git a/changes/2327.fix.md b/changes/2327.fix.md new file mode 100644 index 0000000000..84b0e426d9 --- /dev/null +++ b/changes/2327.fix.md @@ -0,0 +1 @@ +Run batch execution after the batch session starts. diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 4863462528..dbd922a984 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -1740,6 +1740,18 @@ async def execute_batch( SessionFailureEvent(session_id, KernelLifecycleEventReason.TASK_CANCELLED, -2), ) + async def create_batch_execution_task( + self, + session_id: SessionId, + kernel_id: KernelId, + code_to_execute: str, + ) -> None: + self._ongoing_exec_batch_tasks.add( + asyncio.create_task( + self.execute_batch(session_id, kernel_id, code_to_execute), + ), + ) + async def create_kernel( self, session_id: SessionId, @@ -2130,18 +2142,6 @@ async def create_kernel( ), ) - if ( - kernel_config["session_type"] == "batch" - and kernel_config["cluster_role"] == "main" - ): - self._ongoing_exec_batch_tasks.add( - asyncio.create_task( - self.execute_batch( - session_id, kernel_id, kernel_config["startup_command"] or "" - ), - ), - ) - # The startup command for the batch-type sessions will be executed by the manager # upon firing of the "session_started" event. return kernel_creation_info diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 659f27f2b7..b61da3af54 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -626,6 +626,21 @@ async def execute( ) return result + @rpc_function + @collect_error + async def trigger_batch_execution( + self, + session_id: str, + kernel_id: str, + code: str, + ) -> None: + log.info( + "rpc::trigger_batch_execution(k:{0}, s:{1}, code:{2})", kernel_id, session_id, code + ) + await self.agent.create_batch_execution_task( + SessionId(UUID(session_id)), KernelId(UUID(kernel_id)), code + ) + @rpc_function @collect_error async def start_service( diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 6ae33cb501..6a8d408a98 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -177,6 +177,7 @@ class SessionStatus(enum.Enum): "get_logs_from_agent": KernelExecutionFailed, "refresh_session": KernelExecutionFailed, "commit_session": KernelExecutionFailed, + "trigger_batch_execution": KernelExecutionFailed, } diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 91c0fcdcce..f5c35f0136 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -46,7 +46,7 @@ from redis.asyncio import Redis from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import load_only, noload, selectinload +from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria from sqlalchemy.orm.exc import NoResultFound from yarl import URL @@ -1627,11 +1627,23 @@ async def _update_session_occupying_slots() -> None: SessionRow.name, SessionRow.creation_id, SessionRow.access_key, + SessionRow.session_type, ), + selectinload( + SessionRow.kernels, + ).options( + load_only( + KernelRow.id, + KernelRow.agent, + KernelRow.cluster_role, + KernelRow.startup_command, + ) + ), + with_loader_criteria(KernelRow, KernelRow.cluster_role == DEFAULT_ROLE), ) ) async with self.db.begin_readonly_session() as db_session: - updated_session = (await db_session.scalars(query)).first() + updated_session = cast(SessionRow, await db_session.scalar(query)) log.debug( "Producing SessionStartedEvent({}, {})", @@ -1649,6 +1661,9 @@ async def _update_session_occupying_slots() -> None: updated_session.access_key, ), ) + + if updated_session.session_type == SessionTypes.BATCH: + await self.trigger_batch_execution(updated_session) except Exception: log.exception("error while executing _finalize_running") raise @@ -2677,6 +2692,9 @@ async def _restart_kernel(kernel: KernelRow) -> None: SessionStartedEvent(session.id, session.creation_id), ) + if session.session_type == SessionTypes.BATCH: + await self.trigger_batch_execution(session) + async def execute( self, session: SessionRow, @@ -2710,6 +2728,22 @@ async def execute( flush_timeout, ) + async def trigger_batch_execution( + self, + session: SessionRow, + ) -> None: + async with handle_session_exception(self.db, "trigger_batch_execution", session.id): + async with self.agent_cache.rpc_context( + session.main_kernel.agent, + invoke_timeout=30, + order_key=session.main_kernel.id, + ) as rpc: + return await rpc.call.trigger_batch_execution( + str(session.id), + str(session.main_kernel.id), + session.main_kernel.startup_command or "", + ) + async def interrupt_session( self, session: SessionRow,