Skip to content

Commit

Permalink
fix: Run batch execution after session starts (#2327)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
fregataa and achimnol committed Jul 13, 2024
1 parent 5981834 commit 09785db
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 14 deletions.
1 change: 1 addition & 0 deletions changes/2327.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Run batch execution after the batch session starts.
24 changes: 12 additions & 12 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,18 @@ async def execute_batch(
SessionFailureEvent(session_id, KernelLifecycleEventReason.TASK_CANCELLED, -2),
)

async def create_batch_execution_task(
    self,
    session_id: SessionId,
    kernel_id: KernelId,
    code_to_execute: str,
) -> None:
    """
    Schedule the batch execution of a kernel as a background asyncio task.

    The coroutine ``self.execute_batch(session_id, kernel_id, code_to_execute)``
    is wrapped in a task and that task is registered in
    ``self._ongoing_exec_batch_tasks``.  The task is only scheduled here;
    this method does not await its completion.

    :param session_id: The session that owns the kernel.
    :param kernel_id: The kernel in which the code is executed.
    :param code_to_execute: The code passed on to ``execute_batch()``.
    """
    register_task = self._ongoing_exec_batch_tasks.add
    batch_coro = self.execute_batch(session_id, kernel_id, code_to_execute)
    batch_task = asyncio.create_task(batch_coro)
    register_task(batch_task)

async def create_kernel(
self,
session_id: SessionId,
Expand Down Expand Up @@ -2130,18 +2142,6 @@ async def create_kernel(
),
)

if (
kernel_config["session_type"] == "batch"
and kernel_config["cluster_role"] == "main"
):
self._ongoing_exec_batch_tasks.add(
asyncio.create_task(
self.execute_batch(
session_id, kernel_id, kernel_config["startup_command"] or ""
),
),
)

# The startup command of a batch-type session is not executed here.
# The manager triggers it through the "trigger_batch_execution" RPC
# after it has produced the "session_started" event.
return kernel_creation_info
Expand Down
15 changes: 15 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,21 @@ async def execute(
)
return result

@rpc_function
@collect_error
async def trigger_batch_execution(
    self,
    session_id: str,
    kernel_id: str,
    code: str,
) -> None:
    """
    RPC handler that asks the agent to start a batch execution.

    The string identifiers are parsed as UUIDs and wrapped into ``SessionId``
    and ``KernelId`` before being handed to
    ``self.agent.create_batch_execution_task()`` together with ``code``.

    :param session_id: String form of the session UUID.
    :param kernel_id: String form of the kernel UUID.
    :param code: The code to run as the batch execution.
    """
    log.info(
        "rpc::trigger_batch_execution(k:{0}, s:{1}, code:{2})", kernel_id, session_id, code
    )
    schedule_batch = self.agent.create_batch_execution_task
    typed_session_id = SessionId(UUID(session_id))
    typed_kernel_id = KernelId(UUID(kernel_id))
    await schedule_batch(typed_session_id, typed_kernel_id, code)

@rpc_function
@collect_error
async def start_service(
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class SessionStatus(enum.Enum):
"get_logs_from_agent": KernelExecutionFailed,
"refresh_session": KernelExecutionFailed,
"commit_session": KernelExecutionFailed,
"trigger_batch_execution": KernelExecutionFailed,
}


Expand Down
38 changes: 36 additions & 2 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from redis.asyncio import Redis
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only, noload, selectinload
from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria
from sqlalchemy.orm.exc import NoResultFound
from yarl import URL

Expand Down Expand Up @@ -1626,11 +1626,23 @@ async def _update_session_occupying_slots(db_session: AsyncSession) -> None:
SessionRow.name,
SessionRow.creation_id,
SessionRow.access_key,
SessionRow.session_type,
),
selectinload(
SessionRow.kernels,
).options(
load_only(
KernelRow.id,
KernelRow.agent,
KernelRow.cluster_role,
KernelRow.startup_command,
)
),
with_loader_criteria(KernelRow, KernelRow.cluster_role == DEFAULT_ROLE),
)
)
async with self.db.begin_readonly_session() as db_session:
updated_session = (await db_session.scalars(query)).first()
updated_session = cast(SessionRow, await db_session.scalar(query))

log.debug(
"Producing SessionStartedEvent({}, {})",
Expand All @@ -1648,6 +1660,9 @@ async def _update_session_occupying_slots(db_session: AsyncSession) -> None:
updated_session.access_key,
),
)

if updated_session.session_type == SessionTypes.BATCH:
await self.trigger_batch_execution(updated_session)
except Exception:
log.exception("error while executing _finalize_running")
raise
Expand Down Expand Up @@ -2676,6 +2691,9 @@ async def _restart_kernel(kernel: KernelRow) -> None:
SessionStartedEvent(session.id, session.creation_id),
)

if session.session_type == SessionTypes.BATCH:
await self.trigger_batch_execution(session)

async def execute(
self,
session: SessionRow,
Expand Down Expand Up @@ -2709,6 +2727,22 @@ async def execute(
flush_timeout,
)

async def trigger_batch_execution(
    self,
    session: SessionRow,
) -> None:
    """
    Ask the agent of the session's main kernel to start the batch execution.

    Calls the agent's ``trigger_batch_execution`` RPC with the session ID,
    the main kernel ID (both as strings), and the main kernel's
    ``startup_command`` (an empty string when it is unset).  The call runs
    inside ``handle_session_exception()`` under the operation name
    ``"trigger_batch_execution"``.

    :param session: The session row; its ``main_kernel`` must provide
        ``agent``, ``id`` and ``startup_command``.
    """
    # A single ``async with`` with two items enters and exits the context
    # managers in the same order as two nested ``async with`` statements.
    async with (
        handle_session_exception(self.db, "trigger_batch_execution", session.id),
        self.agent_cache.rpc_context(
            session.main_kernel.agent,
            invoke_timeout=30,
            order_key=session.main_kernel.id,
        ) as rpc,
    ):
        remote_trigger = rpc.call.trigger_batch_execution
        session_id_str = str(session.id)
        main_kernel_id_str = str(session.main_kernel.id)
        startup_code = session.main_kernel.startup_command or ""
        return await remote_trigger(session_id_str, main_kernel_id_str, startup_code)

async def interrupt_session(
self,
session: SessionRow,
Expand Down

0 comments on commit 09785db

Please sign in to comment.