From 59be0d4a48278d6c2ae3de29a5f4039f7b4df3f4 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 27 Sep 2024 14:55:37 +0900 Subject: [PATCH] Use ORM object to query and update session row --- .../manager/models/gql_models/session.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py index cd637ed408..4aa6253f61 100644 --- a/src/ai/backend/manager/models/gql_models/session.py +++ b/src/ai/backend/manager/models/gql_models/session.py @@ -16,6 +16,7 @@ import sqlalchemy as sa from dateutil.parser import parse as dtparse from graphene.types.datetime import DateTime as GQLDateTime +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from ai.backend.common.types import ClusterMode, SessionId, SessionResult @@ -52,6 +53,7 @@ get_permission_ctx, ) from ..user import UserRole +from ..utils import execute_with_txn_retry from .kernel import KernelConnection, KernelNode if TYPE_CHECKING: @@ -553,18 +555,21 @@ async def mutate_and_get_payload( f"The priority value {input["priority"]!r} is out of range: " f"[{SESSION_PRIORITY_MIN}, {SESSION_PRIORITY_MAX}]." ) - async with graph_ctx.db.begin_session() as db_sess: - data: dict[str, Any] = {} - set_if_set(input, data, "name") - set_if_set(input, data, "priority") - query = ( - sa.update(SessionRow) - .where(SessionRow.id == session_id) - .values(data) - .returning(SessionRow) - ) - result = await db_sess.execute(query) - session_row = result.fetchone() + data: dict[str, Any] = {} + set_if_set(input, data, "name") + set_if_set(input, data, "priority") + + async def _update(db_sess: AsyncSession) -> SessionRow: + query = sa.select(SessionRow).where(SessionRow.id == session_id) + session_row = cast(SessionRow, await db_sess.scalar(query)) + if "name" in input: + session_row.name = input["name"] + if "priority" in input: + session_row.priority = input["priority"] + return session_row + + async with graph_ctx.db.connect() as db_conn: + session_row = await execute_with_txn_retry(_update, graph_ctx.db.begin_session, db_conn) return cls( ComputeSessionNode.from_row(graph_ctx, session_row), input.get("client_mutation_id"),