From f16b76e81e357b2bd55da3b1dc7aa01b5eeccd26 Mon Sep 17 00:00:00 2001 From: Jihyun Kang Date: Fri, 5 Jul 2024 15:09:16 +0900 Subject: [PATCH] feat: Fetch container logs of a specific kernel (#2364) Co-authored-by: Sanghun Lee --- changes/2364.feature.md | 1 + .../backend/client/cli/session/lifecycle.py | 15 +++- src/ai/backend/client/func/session.py | 4 +- src/ai/backend/manager/api/exceptions.py | 4 ++ src/ai/backend/manager/api/session.py | 69 ++++++++++++++----- src/ai/backend/manager/models/session.py | 10 +++ src/ai/backend/manager/registry.py | 12 +++- 7 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 changes/2364.feature.md diff --git a/changes/2364.feature.md b/changes/2364.feature.md new file mode 100644 index 0000000000..b4316110bf --- /dev/null +++ b/changes/2364.feature.md @@ -0,0 +1 @@ +Add support for fetching container logs of a specific kernel. diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index e7e5322fce..d0c3185df1 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -748,18 +748,27 @@ def ls(session_id, path): @session.command() @click.argument("session_id", metavar="SESSID") -def logs(session_id): +@click.option( + "-k", + "--kernel", + "--kernel-id", + type=str, + default=None, + help="The target kernel id of logs. Default value is None, in which case logs of a main kernel are fetched.", +) +def logs(session_id, kernel: str | None): """ Shows the full console log of a compute session. \b SESSID: Session ID or its alias given when creating the session. """ + _kernel_id = uuid.UUID(kernel) if kernel is not None else None with Session() as session: try: print_wait("Retrieving live container logs...") - kernel = session.ComputeSession(session_id) - result = kernel.get_logs().get("result") + _session = session.ComputeSession(session_id) + result = _session.get_logs(_kernel_id).get("result") logs = result.get("logs") if "logs" in result else "" print(logs) print_done("End of logs.") diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 382e196d93..a5dd416fee 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -699,13 +699,15 @@ async def get_info(self): return await resp.json() @api_function - async def get_logs(self): + async def get_logs(self, kernel_id: UUID | None = None): """ Retrieves the console log of the compute session container. """ params = {} if self.owner_access_key: params["owner_access_key"] = self.owner_access_key + if kernel_id is not None: + params["kernel_id"] = str(kernel_id) prefix = get_naming(api_session.get().api_version, "path") rqst = Request( "GET", diff --git a/src/ai/backend/manager/api/exceptions.py b/src/ai/backend/manager/api/exceptions.py index 3a1166f7ba..1faf5f9af5 100644 --- a/src/ai/backend/manager/api/exceptions.py +++ b/src/ai/backend/manager/api/exceptions.py @@ -216,6 +216,10 @@ class MainKernelNotFound(ObjectNotFound): object_name = "main kernel" +class KernelNotFound(ObjectNotFound): + object_name = "kernel" + + class EndpointNotFound(ObjectNotFound): object_name = "endpoint" diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 2148ffd664..9446381372 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -43,7 +43,7 @@ import trafaret as t from aiohttp import hdrs, web from dateutil.tz import tzutc -from pydantic import BaseModel, Field +from pydantic import AliasChoices, BaseModel, Field from redis.asyncio import Redis from sqlalchemy.orm import noload, selectinload from sqlalchemy.sql.expression import null, true @@ -73,6 +73,7 @@ AgentId, ClusterMode, ImageRegistry, + KernelId, MountPermission, MountTypes, SessionTypes, @@ -2109,19 +2110,36 @@ async def list_files(request: web.Request) -> web.Response: return web.json_response(resp, status=200) +class ContainerLogRequestModel(BaseModel): + owner_access_key: str | None = Field( + validation_alias=AliasChoices("owner_access_key", "ownerAccessKey"), + default=None, + ) + kernel_id: uuid.UUID | None = Field( + validation_alias=AliasChoices("kernel_id", "kernelId"), + description="Target kernel to get container logs.", + default=None, + ) + + @server_status_required(READ_ALLOWED) @auth_required -@check_api_params( - t.Dict({ - t.Key("owner_access_key", default=None): t.Null | t.String, - }) -) -async def get_container_logs(request: web.Request, params: Any) -> web.Response: +@pydantic_params_api_handler(ContainerLogRequestModel) +async def get_container_logs( + request: web.Request, params: ContainerLogRequestModel +) -> web.Response: root_ctx: RootContext = request.app["_root.context"] session_name: str = request.match_info["session_name"] - requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + requester_access_key, owner_access_key = await get_access_key_scopes( + request, {"owner_access_key": params.owner_access_key} + ) + kernel_id = KernelId(params.kernel_id) if params.kernel_id is not None else None log.info( - "GET_CONTAINER_LOG (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_name + "GET_CONTAINER_LOG (ak:{}/{}, s:{}, k:{})", + requester_access_key, + owner_access_key, + session_name, + kernel_id, ) resp = {"result": {"logs": ""}} async with root_ctx.db.begin_readonly_session() as db_sess: @@ -2130,25 +2148,38 @@ async def get_container_logs(request: web.Request, params: Any) -> web.Response: session_name, owner_access_key, allow_stale=True, - kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY + if kernel_id is None + else KernelLoadingStrategy.ALL_KERNELS, ) - if ( - compute_session.status in DEAD_SESSION_STATUSES - and compute_session.main_kernel.container_log is not None - ): - log.debug("returning log from database record") - resp["result"]["logs"] = compute_session.main_kernel.container_log.decode("utf-8") - return web.json_response(resp, status=200) + + if compute_session.status in DEAD_SESSION_STATUSES: + if kernel_id is not None: + # Get logs from the specific kernel + kernel_row = compute_session.get_kernel_by_id(kernel_id) + kernel_log = kernel_row.container_log + else: + # Get logs from the main kernel + kernel_log = compute_session.main_kernel.container_log + if kernel_log is not None: + # Get logs from database record + log.debug("returning log from database record") + resp["result"]["logs"] = kernel_log.decode("utf-8") + return web.json_response(resp, status=200) + try: registry = root_ctx.registry await registry.increment_session_usage(compute_session) - resp["result"]["logs"] = await registry.get_logs_from_agent(compute_session) + resp["result"]["logs"] = await registry.get_logs_from_agent( + session=compute_session, kernel_id=kernel_id + ) log.debug("returning log from agent") except BackendError: log.exception( - "GET_CONTAINER_LOG(ak:{}/{}, s:{}): unexpected error", + "GET_CONTAINER_LOG(ak:{}/{}, kernel_id: {}, s:{}): unexpected error", requester_access_key, owner_access_key, + kernel_id, session_name, ) raise diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 6eebeb4e55..52407de416 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -44,6 +44,7 @@ KernelCreationFailed, KernelDestructionFailed, KernelExecutionFailed, + KernelNotFound, KernelRestartFailed, MainKernelNotFound, SessionNotFound, @@ -79,6 +80,7 @@ from .gql import GraphQueryContext +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] __all__ = ( "determine_session_status", @@ -730,6 +732,14 @@ def resource_opts(self) -> dict[str, Any]: def is_private(self) -> bool: return any([kernel.is_private for kernel in self.kernels]) + def get_kernel_by_id(self, kernel_id: KernelId) -> KernelRow: + kerns = tuple(kern for kern in self.kernels if kern.id == kernel_id) + if len(kerns) > 1: + raise TooManyKernelsFound(f"Multiple kernels found (id:{kernel_id}).") + if len(kerns) == 0: + raise KernelNotFound(f"Session has no such kernel (sid:{self.id}, kid:{kernel_id}))") + return kerns[0] + def get_kernel_by_cluster_name(self, cluster_name: str) -> KernelRow: kerns = tuple(kern for kern in self.kernels if kern.cluster_name == cluster_name) if len(kerns) > 1: diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 3427285bf6..46566e1e70 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -2810,14 +2810,20 @@ async def list_files( async def get_logs_from_agent( self, session: SessionRow, + kernel_id: KernelId | None = None, ) -> str: async with handle_session_exception(self.db, "get_logs_from_agent", session.id): + kernel = ( + session.get_kernel_by_id(kernel_id) + if kernel_id is not None + else session.main_kernel + ) async with self.agent_cache.rpc_context( - session.main_kernel.agent, + agent_id=kernel.agent, invoke_timeout=30, - order_key=session.main_kernel.id, + order_key=kernel.id, ) as rpc: - reply = await rpc.call.get_logs(str(session.main_kernel.id)) + reply = await rpc.call.get_logs(str(kernel.id)) return reply["logs"] async def increment_session_usage(