Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Run batch execution after session starts #2327

Merged
merged 13 commits into from
Jul 13, 2024
Merged
1 change: 1 addition & 0 deletions changes/2327.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Run batch execution after the batch session starts.
24 changes: 12 additions & 12 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,18 @@ async def execute_batch(
SessionFailureEvent(session_id, KernelLifecycleEventReason.TASK_CANCELLED, -2),
)

async def create_batch_execution_task(
self,
session_id: SessionId,
kernel_id: KernelId,
code_to_execute: str,
) -> None:
self._ongoing_exec_batch_tasks.add(
asyncio.create_task(
self.execute_batch(session_id, kernel_id, code_to_execute),
),
)

async def create_kernel(
self,
session_id: SessionId,
Expand Down Expand Up @@ -2124,18 +2136,6 @@ async def create_kernel(
),
)

if (
kernel_config["session_type"] == "batch"
and kernel_config["cluster_role"] == "main"
):
self._ongoing_exec_batch_tasks.add(
asyncio.create_task(
self.execute_batch(
session_id, kernel_id, kernel_config["startup_command"] or ""
),
),
)

# The startup command for the batch-type sessions will be executed by the manager
# upon firing of the "session_started" event.
return kernel_creation_info
Expand Down
15 changes: 15 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,21 @@ async def execute(
)
return result

@rpc_function
@collect_error
async def trigger_batch_execution(
self,
session_id: str,
kernel_id: str,
code: str,
) -> None:
log.info(
"rpc::trigger_batch_execution(k:{0}, s:{1}, code:{2})", kernel_id, session_id, code
)
await self.agent.create_batch_execution_task(
SessionId(UUID(session_id)), KernelId(UUID(kernel_id)), code
)

@rpc_function
@collect_error
async def start_service(
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class SessionStatus(enum.Enum):
"get_logs_from_agent": KernelExecutionFailed,
"refresh_session": KernelExecutionFailed,
"commit_session": KernelExecutionFailed,
"trigger_batch_execution": KernelExecutionFailed,
}


Expand Down
35 changes: 33 additions & 2 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from redis.asyncio import Redis
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only, noload, selectinload
from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria
from sqlalchemy.orm.exc import NoResultFound
from yarl import URL

Expand Down Expand Up @@ -1626,11 +1626,23 @@ async def _update_session_occupying_slots(db_session: AsyncSession) -> None:
SessionRow.name,
SessionRow.creation_id,
SessionRow.access_key,
SessionRow.session_type,
),
selectinload(
SessionRow.kernels,
).options(
load_only(
KernelRow.id,
KernelRow.agent,
KernelRow.cluster_role,
KernelRow.startup_command,
)
),
with_loader_criteria(KernelRow, KernelRow.cluster_role == DEFAULT_ROLE),
)
)
async with self.db.begin_readonly_session() as db_session:
updated_session = (await db_session.scalars(query)).first()
updated_session = cast(SessionRow, await db_session.scalar(query))

log.debug(
"Producing SessionStartedEvent({}, {})",
Expand All @@ -1648,6 +1660,9 @@ async def _update_session_occupying_slots(db_session: AsyncSession) -> None:
updated_session.access_key,
),
)

if updated_session.session_type == SessionTypes.BATCH:
await self.trigger_batch_execution(updated_session)
except Exception:
log.exception("error while executing _finalize_running")
raise
Expand Down Expand Up @@ -2709,6 +2724,22 @@ async def execute(
flush_timeout,
)

async def trigger_batch_execution(
self,
session: SessionRow,
) -> None:
async with handle_session_exception(self.db, "trigger_batch_execution", session.id):
async with self.agent_cache.rpc_context(
session.main_kernel.agent,
invoke_timeout=30,
order_key=session.main_kernel.id,
) as rpc:
return await rpc.call.trigger_batch_execution(
str(session.id),
str(session.main_kernel.id),
session.main_kernel.startup_command or "",
)

async def interrupt_session(
self,
session: SessionRow,
Expand Down
Loading