From f55e7d6ff2969d393419f6adde520fc86fe2e826 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 27 Sep 2024 17:05:57 +0900 Subject: [PATCH] resolve conflict --- .../manager/models/gql_models/session.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py index 9870e0bc2f..ceeb3cf55b 100644 --- a/src/ai/backend/manager/models/gql_models/session.py +++ b/src/ai/backend/manager/models/gql_models/session.py @@ -53,6 +53,7 @@ get_permission_ctx, ) from ..user import UserRole +from ..utils import execute_with_txn_retry from .kernel import KernelConnection, KernelNode if TYPE_CHECKING: @@ -544,7 +545,7 @@ async def mutate_and_get_payload( root: Any, info: graphene.ResolveInfo, **input, - ) -> Self: + ) -> ModifyComputeSession: graph_ctx: GraphQueryContext = info.context _, raw_session_id = cast(ResolvedGlobalID, input["id"]) session_id = SessionId(uuid.UUID(raw_session_id)) @@ -566,9 +567,18 @@ async def _update(db_session: AsyncSession) -> Optional[SessionRow]: .values(data) .returning(SessionRow) ) - result = await db_sess.execute(query) - session_row = result.fetchone() - return cls( + _stmt = ( + sa.select(SessionRow) + .from_statement(_update_stmt) + .execution_options(populate_existing=True) + ) + return await db_session.scalar(_stmt) + + async with graph_ctx.db.connect() as db_conn: + session_row = await execute_with_txn_retry(_update, graph_ctx.db.begin_session, db_conn) + if session_row is None: + raise ValueError(f"Session not found (id:{session_id})") + return ModifyComputeSession( ComputeSessionNode.from_row(graph_ctx, session_row), input.get("client_mutation_id"), )